"""2D segmentation network: SwinV2-FWTA encoder, structure-aware decoder, dense heads."""

from __future__ import annotations

from argparse import Namespace
from pathlib import Path
from typing import Any, Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F

from .decoder_2d import StructureAwareDecoder2d
from .layers_2d import Conv2dBN
from .swinv2_fwta_encoder_2d import SwinV2FWTAEncoder2d


class _DenseHead2d(nn.Module):
    """Shared implementation of the per-pixel prediction heads.

    Applies ``Conv2dBN -> ReLU -> 1x1 Conv2d`` and then bilinearly resizes the
    result to a requested spatial size. ``SegmentationHead2d`` and
    ``BoundaryHead2d`` differ only in the default number of output channels,
    so both delegate to this class.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Number of channels to predict.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        # The attribute name ``block`` and the layer order are kept as-is so
        # state-dict keys (``block.<index>.*``) stay compatible with
        # checkpoints saved before the heads shared a base class.
        self.block = nn.Sequential(
            # NOTE(review): positional args are presumably
            # (kernel_size=3, stride=1, padding=1) — confirm against Conv2dBN.
            Conv2dBN(in_channels, in_channels, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=True),
        )

    def forward(self, x: torch.Tensor, output_size: tuple[int, int]) -> torch.Tensor:
        """Predict from ``x`` and resize the prediction to ``output_size``.

        Args:
            x: Feature map fed to the head.
            output_size: Target spatial size passed to ``F.interpolate``.

        Returns:
            The head output, bilinearly interpolated
            (``align_corners=False``) to ``output_size``.
        """
        x = self.block(x)
        return F.interpolate(x, size=output_size, mode="bilinear", align_corners=False)


class SegmentationHead2d(_DenseHead2d):
    """Head producing segmentation logits with ``out_channels`` channels.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Number of output channels (required, no default).
    """


class BoundaryHead2d(_DenseHead2d):
    """Head producing boundary logits; defaults to a single output channel.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Number of output channels. Defaults to ``1``.
    """

    def __init__(self, in_channels: int, out_channels: int = 1) -> None:
        super().__init__(in_channels, out_channels)


class GlobalTokenConditioning2d(nn.Module):
    """Channel-wise modulation of decoder features by a global token.

    Uses the global foreground token updated by FWTA to modulate the decoded
    features per channel. The gate is
    ``LayerNorm -> Linear -> GELU -> Linear -> Sigmoid``, with a hidden width
    of ``max(feature_channels // 2, 32)``.

    Args:
        token_channels: Size of the last dimension of the global token.
        feature_channels: Number of gate values produced per token.
    """

    def __init__(self, token_channels: int, feature_channels: int) -> None:
        super().__init__()
        hidden_channels = max(feature_channels // 2, 32)
        self.gate = nn.Sequential(
            nn.LayerNorm(token_channels),
            nn.Linear(token_channels, hidden_channels),
            nn.GELU(),
            nn.Linear(hidden_channels, feature_channels),
            nn.Sigmoid(),
        )

    def forward(
        self,
        x: torch.Tensor,
        global_token: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Scale ``x`` by ``1 + gate`` and return both the result and the gate.

        Args:
            x: Feature map to modulate.
            global_token: Token the gate is computed from.

        Returns:
            ``(modulated, channel_gate)``. ``channel_gate`` is the sigmoid
            gate with two trailing singleton dimensions appended;
            ``modulated`` is ``x * (1.0 + channel_gate)``.
        """
        # NOTE(review): assumes global_token is (batch, token_channels) so the
        # gate broadcasts against a (batch, feature_channels, H, W) map —
        # confirm against what the encoder actually emits.
        channel_gate = self.gate(global_token).unsqueeze(-1).unsqueeze(-1)
        # The sigmoid keeps the gate in (0, 1), so the multiplier lies in
        # (1, 2): features are only ever amplified, never suppressed.
        return x * (1.0 + channel_gate), channel_gate


class SegmentationNet2d(nn.Module):
    """First-version skeleton of the main ultrasound segmentation network.

    Current responsibilities:
    - The encoder outputs multi-scale features and a stability map.
    - The structure-aware decoder recovers segmentation features.
    - Both a segmentation map and a boundary map are produced.

    Args:
        num_classes: Number of channels of the segmentation logits.
        model_name: Forwarded unchanged to ``SwinV2FWTAEncoder2d``.
        config_path: Forwarded unchanged to ``SwinV2FWTAEncoder2d``.
        weight_path: Forwarded unchanged to ``SwinV2FWTAEncoder2d``.
        args: Forwarded unchanged to ``SwinV2FWTAEncoder2d``.
        decoder_channels: Forwarded unchanged to ``StructureAwareDecoder2d``.
        load_weights: Forwarded unchanged to ``SwinV2FWTAEncoder2d``.
        **encoder_kwargs: Extra keyword arguments for the encoder.
    """

    def __init__(
        self,
        num_classes: int,
        model_name: str | None = None,
        config_path: str | Path | None = None,
        weight_path: str | Path | None = None,
        args: Namespace | None = None,
        *,
        decoder_channels: Sequence[int] | None = None,
        load_weights: bool = True,
        **encoder_kwargs: Any,
    ) -> None:
        super().__init__()
        self.encoder = SwinV2FWTAEncoder2d(
            model_name=model_name,
            config_path=config_path,
            weight_path=weight_path,
            args=args,
            load_weights=load_weights,
            **encoder_kwargs,
        )
        self.decoder = StructureAwareDecoder2d(
            encoder_channels=self.encoder.stage_channels,
            decoder_channels=decoder_channels,
        )
        # The gate is sized from the last encoder stage's channel count
        # (token side) and the decoder's output channel count (feature side).
        self.global_conditioning = GlobalTokenConditioning2d(
            token_channels=self.encoder.stage_channels[-1],
            feature_channels=self.decoder.out_channels,
        )
        self.segmentation_head = SegmentationHead2d(self.decoder.out_channels, num_classes)
        self.boundary_head = BoundaryHead2d(self.decoder.out_channels, out_channels=1)

    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor | list[torch.Tensor]]:
        """Run encoder, decoder, global conditioning and both heads.

        Args:
            x: Input batch; its last two dimensions define the output size.

        Returns:
            A dict with the keys ``seg_logits``, ``boundary_logits``,
            ``stability_map`` (resized to the input's spatial size),
            ``encoder_features``, ``decoder_features``,
            ``conditioned_decoder_feature``, ``deepest_feature``,
            ``global_token``, ``global_channel_gate`` and ``token_gate``.
        """
        encoder_outputs = self.encoder(x)
        features = encoder_outputs["features"]
        stability_map = encoder_outputs["stability_map"]
        global_token = encoder_outputs["global_token"]

        decoder_out, decoder_features = self.decoder(
            features=features,
            stability_map=stability_map,
        )
        conditioned_decoder_out, global_channel_gate = self.global_conditioning(
            decoder_out,
            global_token,
        )

        # Every dense output is brought back to the input's spatial resolution.
        output_size = x.shape[-2:]
        seg_logits = self.segmentation_head(conditioned_decoder_out, output_size=output_size)
        boundary_logits = self.boundary_head(conditioned_decoder_out, output_size=output_size)

        return {
            "seg_logits": seg_logits,
            "boundary_logits": boundary_logits,
            "stability_map": F.interpolate(
                stability_map, size=output_size, mode="bilinear", align_corners=False
            ),
            "encoder_features": features,
            "decoder_features": decoder_features,
            "conditioned_decoder_feature": conditioned_decoder_out,
            "deepest_feature": encoder_outputs["deepest_feature"],
            "global_token": global_token,
            "global_channel_gate": global_channel_gate,
            "token_gate": encoder_outputs["token_gate"],
        }