from __future__ import annotations from typing import Sequence import torch import torch.nn as nn import torch.nn.functional as F from .layers_2d import Conv2dBN class DecodeRefineBlock2d(nn.Module): """ 对解码后的融合特征做轻量残差细化。 """ def __init__(self, channels: int) -> None: super().__init__() self.refine = nn.Sequential( Conv2dBN(channels, channels, 3, 1, 1), nn.ReLU(inplace=True), Conv2dBN(channels, channels, 3, 1, 1), ) def forward(self, x: torch.Tensor) -> torch.Tensor: return x + self.refine(x) class SegmentationDecodeBlock2d(nn.Module): """ 单层解码块:上采样高层特征,与 skip 特征融合后细化。 """ def __init__(self, in_channels: int, skip_channels: int, out_channels: int) -> None: super().__init__() self.high_proj = nn.Sequential( Conv2dBN(in_channels, out_channels, 1, 1, 0), nn.ReLU(inplace=True), ) self.skip_proj = nn.Sequential( Conv2dBN(skip_channels, out_channels, 1, 1, 0), nn.ReLU(inplace=True), ) self.fuse = nn.Sequential( Conv2dBN(out_channels * 2, out_channels, 3, 1, 1), nn.ReLU(inplace=True), Conv2dBN(out_channels, out_channels, 3, 1, 1), nn.ReLU(inplace=True), ) self.refine = DecodeRefineBlock2d(out_channels) def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor: x = F.interpolate(x, size=skip.shape[-2:], mode="bilinear", align_corners=False) x = self.high_proj(x) skip = self.skip_proj(skip) x = self.fuse(torch.cat([x, skip], dim=1)) return self.refine(x) class SegmentationDecoder2d(nn.Module): """ 纯净的多尺度解码器骨架。 输入特征默认按从浅到深排列,最后一个特征视为最深层输入。 """ def __init__( self, encoder_channels: Sequence[int], decoder_channels: Sequence[int] | None = None, ) -> None: super().__init__() if len(encoder_channels) < 2: raise ValueError("encoder_channels must contain at least two stages.") self.encoder_channels = list(encoder_channels) if decoder_channels is None: decoder_channels = list(reversed(self.encoder_channels[:-1])) if len(decoder_channels) != len(self.encoder_channels) - 1: raise ValueError("decoder_channels length must match len(encoder_channels) - 1.") in_channels = self.encoder_channels[-1] skip_channels = list(reversed(self.encoder_channels[:-1])) blocks = [] for skip_ch, out_ch in zip(skip_channels, decoder_channels): blocks.append(SegmentationDecodeBlock2d(in_channels, skip_ch, out_ch)) in_channels = out_ch self.blocks = nn.ModuleList(blocks) self.out_channels = in_channels def forward(self, features: Sequence[torch.Tensor]) -> tuple[torch.Tensor, list[torch.Tensor]]: if len(features) != len(self.encoder_channels): raise ValueError( f"feature count mismatch: got {len(features)}, expected {len(self.encoder_channels)}" ) x = features[-1] skips = list(reversed(features[:-1])) decoder_features = [] for block, skip in zip(self.blocks, skips): x = block(x, skip) decoder_features.append(x) return x, decoder_features