from __future__ import annotations

from argparse import Namespace
from pathlib import Path
from typing import Any, Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F

from .decoder_2d import SegmentationDecoder2d
from .layers_2d import Conv2dBN
from .swinv2_encoder_2d import SwinV2Encoder2d


class SegmentationHead2d(nn.Module):
    """Project decoder features to per-class logits and resize them.

    The head is a ``Conv2dBN`` -> ``ReLU`` -> 1x1 ``Conv2d`` stack stored in
    ``self.block``. The attribute name and layer order determine the
    ``state_dict`` keys, so keep them stable. The stack is followed by
    bilinear resampling to a caller-supplied spatial size.

    Args:
        in_channels: Channel count of the incoming feature map; the
            ``Conv2dBN`` stage keeps this width unchanged.
        out_channels: Number of logit channels produced by the final
            1x1 convolution.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        # NOTE(review): the positional (3, 1, 1) presumably mean
        # kernel_size=3, stride=1, padding=1 -- confirm against Conv2dBN.
        refine = Conv2dBN(in_channels, in_channels, 3, 1, 1)
        activation = nn.ReLU(inplace=True)
        classify = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=True)
        self.block = nn.Sequential(refine, activation, classify)

    def forward(
        self,
        x: torch.Tensor,
        output_size: tuple[int, int],
    ) -> torch.Tensor:
        """Return logits for ``x``, bilinearly resized to ``output_size``.

        Args:
            x: Feature map. Bilinear interpolation requires a 4-D
                (batch, channels, height, width) tensor.
            output_size: Target (height, width) of the returned logits.
        """
        logits = self.block(x)
        resized = F.interpolate(
            logits,
            size=output_size,
            mode="bilinear",
            align_corners=False,
        )
        return resized


class SegmentationModel2d(nn.Module):
    """Simplified ultrasound segmentation network.

    Pipeline: ``SwinV2Encoder2d`` -> ``SegmentationDecoder2d`` ->
    ``SegmentationHead2d``. The head's logits are resized to the spatial
    size (last two dimensions) of the network input.

    Args:
        num_classes: Number of output channels of the segmentation head.
        model_name, config_path, weight_path, args: Forwarded unchanged to
            ``SwinV2Encoder2d``.
        decoder_channels: Forwarded to ``SegmentationDecoder2d``. ``None``
            presumably selects the decoder's default widths (confirm there).
        load_weights: Forwarded to ``SwinV2Encoder2d``.
        **encoder_kwargs: Extra keyword arguments for ``SwinV2Encoder2d``.
    """

    def __init__(
        self,
        num_classes: int,
        model_name: str | None = None,
        config_path: str | Path | None = None,
        weight_path: str | Path | None = None,
        args: Namespace | None = None,
        *,
        decoder_channels: Sequence[int] | None = None,
        load_weights: bool = True,
        **encoder_kwargs: Any,
    ) -> None:
        super().__init__()
        self.encoder = SwinV2Encoder2d(
            model_name=model_name,
            config_path=config_path,
            weight_path=weight_path,
            args=args,
            load_weights=load_weights,
            **encoder_kwargs,
        )

        # The decoder is sized from the encoder's per-stage channel counts.
        stage_channels = self.encoder.stage_channels
        self.decoder = SegmentationDecoder2d(
            encoder_channels=stage_channels,
            decoder_channels=decoder_channels,
        )

        head_in_channels = self.decoder.out_channels
        self.segmentation_head = SegmentationHead2d(head_in_channels, num_classes)

    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        """Run the network and return ``{"seg_logits": logits}``.

        The logits have ``num_classes`` channels and the same last two
        (spatial) dimensions as ``x``.
        """
        # The encoder output is indexed by key; only "features" is consumed.
        encoded = self.encoder(x)
        features = encoded["features"]

        # The decoder yields a two-element result; the second item is unused here.
        decoded, _aux = self.decoder(features)

        target_size = x.shape[-2:]
        logits = self.segmentation_head(decoded, output_size=target_size)
        return {"seg_logits": logits}