| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138 |
- from __future__ import annotations
- from argparse import Namespace
- from pathlib import Path
- from typing import Any, Sequence
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from .decoder_2d import StructureAwareDecoder2d
- from .layers_2d import Conv2dBN
- from .swinv2_fwta_encoder_2d import SwinV2FWTAEncoder2d
class SegmentationHead2d(nn.Module):
    """Turn decoder features into segmentation logits at a requested size.

    The head runs a ``Conv2dBN`` layer that keeps the channel count, a ReLU,
    and a biased 1x1 convolution down to ``out_channels``; the logits are then
    resized with bilinear interpolation.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        """Build the projection stack.

        Args:
            in_channels: Channel count of the incoming feature map.
            out_channels: Number of logit channels to produce.
        """
        super().__init__()
        # The order of this list fixes the ``block.<index>`` parameter names
        # in the state dict, so it must not be rearranged.
        layers = [
            # NOTE(review): the positional 3, 1, 1 presumably mean kernel size,
            # stride and padding -- confirm against Conv2dBN's signature.
            Conv2dBN(in_channels, in_channels, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=True),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor, output_size: tuple[int, int]) -> torch.Tensor:
        """Compute logits and resize them to ``output_size``.

        Args:
            x: Feature map fed to the projection stack.
            output_size: Target spatial size handed to ``F.interpolate``.

        Returns:
            The projected logits, bilinearly resampled (``align_corners=False``)
            to ``output_size``.
        """
        return F.interpolate(
            self.block(x),
            size=output_size,
            mode="bilinear",
            align_corners=False,
        )
class BoundaryHead2d(nn.Module):
    """Predict boundary logits from decoder features at a requested size.

    Structurally this is a ``Conv2dBN`` layer that preserves the channel
    count, followed by ReLU and a biased 1x1 convolution; the prediction is
    bilinearly resized afterwards.
    """

    def __init__(self, in_channels: int, out_channels: int = 1) -> None:
        """Create the boundary prediction layers.

        Args:
            in_channels: Channel count of the incoming feature map.
            out_channels: Number of boundary channels; defaults to one.
        """
        super().__init__()
        # NOTE(review): the trailing 3, 1, 1 look like kernel size, stride and
        # padding -- verify against Conv2dBN before relying on it.
        refine = Conv2dBN(in_channels, in_channels, 3, 1, 1)
        activation = nn.ReLU(inplace=True)
        project = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=True)
        # Positions 0/1/2 inside the Sequential define the state-dict keys.
        self.block = nn.Sequential(refine, activation, project)

    def forward(self, x: torch.Tensor, output_size: tuple[int, int]) -> torch.Tensor:
        """Compute boundary logits and resize them to ``output_size``.

        Args:
            x: Feature map fed to the prediction layers.
            output_size: Target spatial size handed to ``F.interpolate``.

        Returns:
            Boundary logits resampled bilinearly (``align_corners=False``) to
            ``output_size``.
        """
        boundary = self.block(x)
        resize_options = {"mode": "bilinear", "align_corners": False}
        return F.interpolate(boundary, size=output_size, **resize_options)
class GlobalTokenConditioning2d(nn.Module):
    """
    Channel-wise modulation of decoder features, driven by the global
    foreground token after it has been updated by FWTA.
    """

    def __init__(self, token_channels: int, feature_channels: int) -> None:
        """Build the token-to-gate MLP.

        Args:
            token_channels: Size of the token's last dimension.
            feature_channels: Number of channels in the feature map to gate.
        """
        super().__init__()
        # Bottleneck width: half the feature channels, never fewer than 32.
        bottleneck = max(feature_channels // 2, 32)
        # Layer positions fix the ``gate.<index>`` state-dict keys.
        mlp_layers = (
            nn.LayerNorm(token_channels),
            nn.Linear(token_channels, bottleneck),
            nn.GELU(),
            nn.Linear(bottleneck, feature_channels),
            nn.Sigmoid(),
        )
        self.gate = nn.Sequential(*mlp_layers)

    def forward(self, x: torch.Tensor, global_token: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Scale ``x`` channel-wise by a gate derived from ``global_token``.

        Args:
            x: Feature map to modulate.
                # assumes (batch, feature_channels, H, W) -- TODO confirm
            global_token: Token whose last dimension is ``token_channels``.
                # assumes (batch, token_channels) -- TODO confirm

        Returns:
            A pair ``(modulated, gate)`` where ``gate`` is the sigmoid output
            with two trailing singleton dimensions appended, and
            ``modulated`` equals ``x * (1 + gate)``.
        """
        gate_values = self.gate(global_token)
        # Append two singleton axes so the gate broadcasts over the trailing
        # (spatial) dimensions of ``x``.
        channel_gate = gate_values[..., None, None]
        modulated = x * (1.0 + channel_gate)
        return modulated, channel_gate
class SegmentationNet2d(nn.Module):
    """
    First-version skeleton of the main ultrasound segmentation network.

    Current responsibilities:
    - the encoder produces multi-scale features and a stability map
    - the structure-aware decoder recovers segmentation features
    - a segmentation map and a boundary map are produced together
    """

    def __init__(
        self,
        num_classes: int,
        model_name: str | None = None,
        config_path: str | Path | None = None,
        weight_path: str | Path | None = None,
        args: Namespace | None = None,
        *,
        decoder_channels: Sequence[int] | None = None,
        load_weights: bool = True,
        **encoder_kwargs: Any,
    ) -> None:
        """Assemble encoder, decoder, token conditioning and both heads.

        Args:
            num_classes: Number of channels of the segmentation logits.
            model_name: Forwarded unchanged to ``SwinV2FWTAEncoder2d``.
            config_path: Forwarded unchanged to ``SwinV2FWTAEncoder2d``.
            weight_path: Forwarded unchanged to ``SwinV2FWTAEncoder2d``.
            args: Forwarded unchanged to ``SwinV2FWTAEncoder2d``.
            decoder_channels: Forwarded to ``StructureAwareDecoder2d``.
            load_weights: Forwarded unchanged to ``SwinV2FWTAEncoder2d``.
            **encoder_kwargs: Extra keyword arguments for the encoder.
        """
        super().__init__()

        # Sub-modules are created in a fixed order: encoder, decoder,
        # conditioning, segmentation head, boundary head.
        encoder_options = dict(
            model_name=model_name,
            config_path=config_path,
            weight_path=weight_path,
            args=args,
            load_weights=load_weights,
        )
        self.encoder = SwinV2FWTAEncoder2d(**encoder_options, **encoder_kwargs)

        self.decoder = StructureAwareDecoder2d(
            encoder_channels=self.encoder.stage_channels,
            decoder_channels=decoder_channels,
        )

        # The conditioning token is sized by the last encoder stage's channels.
        self.global_conditioning = GlobalTokenConditioning2d(
            token_channels=self.encoder.stage_channels[-1],
            feature_channels=self.decoder.out_channels,
        )

        self.segmentation_head = SegmentationHead2d(self.decoder.out_channels, num_classes)
        self.boundary_head = BoundaryHead2d(self.decoder.out_channels, out_channels=1)

    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor | list[torch.Tensor]]:
        """Run the full network and collect predictions plus intermediates.

        Args:
            x: Input batch; its last two dimensions define the output size.

        Returns:
            A dict with these keys:
            ``seg_logits`` and ``boundary_logits`` (head outputs resized to the
            input's last two dimensions), ``stability_map`` (the encoder's
            stability map bilinearly resized to the same size),
            ``encoder_features``, ``decoder_features``,
            ``conditioned_decoder_feature`` (decoder output after token
            gating), ``deepest_feature``, ``global_token``,
            ``global_channel_gate`` and ``token_gate``.
        """
        enc = self.encoder(x)
        multi_scale = enc["features"]
        raw_stability = enc["stability_map"]

        decoded, decoder_pyramid = self.decoder(
            features=multi_scale,
            stability_map=raw_stability,
        )

        # Gate the decoder output channel-wise with the encoder's global token.
        gated, channel_gate = self.global_conditioning(decoded, enc["global_token"])

        # Every dense output is brought back to the input's spatial size.
        target_hw = x.shape[-2:]
        seg = self.segmentation_head(gated, output_size=target_hw)
        boundary = self.boundary_head(gated, output_size=target_hw)
        stability_full = F.interpolate(
            raw_stability,
            size=target_hw,
            mode="bilinear",
            align_corners=False,
        )

        outputs = {
            "seg_logits": seg,
            "boundary_logits": boundary,
            "stability_map": stability_full,
            "encoder_features": multi_scale,
            "decoder_features": decoder_pyramid,
            "conditioned_decoder_feature": gated,
            "deepest_feature": enc["deepest_feature"],
            "global_token": enc["global_token"],
            "global_channel_gate": channel_gate,
            "token_gate": enc["token_gate"],
        }
        return outputs
|