from __future__ import annotations from typing import Sequence import torch import torch.nn as nn import torch.nn.functional as F from .layers_2d import Conv2dBN class BoundaryRefineBlock2d(nn.Module): """ 使用边界提示和稳定性图对解码特征做轻量细化。 """ def __init__(self, channels: int) -> None: super().__init__() self.refine = nn.Sequential( Conv2dBN(channels, channels, 3, 1, 1), nn.ReLU(inplace=True), Conv2dBN(channels, channels, 3, 1, 1), ) def forward( self, x: torch.Tensor, boundary_hint: torch.Tensor | None = None, stability_map: torch.Tensor | None = None, ) -> torch.Tensor: modulator = 1.0 if stability_map is not None: stability_map = F.interpolate( stability_map, size=x.shape[-2:], mode="bilinear", align_corners=False ) modulator = modulator + stability_map if boundary_hint is not None: boundary_hint = F.interpolate( boundary_hint, size=x.shape[-2:], mode="bilinear", align_corners=False ) modulator = modulator + boundary_hint return x + self.refine(x * modulator) class StructureAwareDecodeBlock2d(nn.Module): """ 单层结构感知解码块。 """ def __init__(self, in_channels: int, skip_channels: int, out_channels: int) -> None: super().__init__() self.high_proj = nn.Sequential( Conv2dBN(in_channels, out_channels, 1, 1, 0), nn.ReLU(inplace=True), ) self.skip_proj = nn.Sequential( Conv2dBN(skip_channels, out_channels, 1, 1, 0), nn.ReLU(inplace=True), ) self.fuse = nn.Sequential( Conv2dBN(out_channels * 2, out_channels, 3, 1, 1), nn.ReLU(inplace=True), Conv2dBN(out_channels, out_channels, 3, 1, 1), nn.ReLU(inplace=True), ) self.refine = BoundaryRefineBlock2d(out_channels) def forward( self, x: torch.Tensor, skip: torch.Tensor, stability_map: torch.Tensor | None = None, boundary_hint: torch.Tensor | None = None, ) -> torch.Tensor: x = F.interpolate(x, size=skip.shape[-2:], mode="bilinear", align_corners=False) x = self.high_proj(x) skip = self.skip_proj(skip) x = self.fuse(torch.cat([x, skip], dim=1)) x = self.refine(x, boundary_hint=boundary_hint, stability_map=stability_map) return x class StructureAwareDecoder2d(nn.Module): """ 第一版结构感知解码器骨架。 输入特征默认按从浅到深排列,最后一个特征视为最深层输入。 """ def __init__(self, encoder_channels: Sequence[int], decoder_channels: Sequence[int] | None = None) -> None: super().__init__() if len(encoder_channels) < 2: raise ValueError("encoder_channels must contain at least two stages.") self.encoder_channels = list(encoder_channels) if decoder_channels is None: decoder_channels = list(reversed(self.encoder_channels[:-1])) if len(decoder_channels) != len(self.encoder_channels) - 1: raise ValueError("decoder_channels length must match len(encoder_channels) - 1.") in_channels = self.encoder_channels[-1] skip_channels = list(reversed(self.encoder_channels[:-1])) blocks = [] for skip_ch, out_ch in zip(skip_channels, decoder_channels): blocks.append(StructureAwareDecodeBlock2d(in_channels, skip_ch, out_ch)) in_channels = out_ch self.blocks = nn.ModuleList(blocks) self.out_channels = in_channels def forward( self, features: Sequence[torch.Tensor], stability_map: torch.Tensor | None = None, boundary_hints: Sequence[torch.Tensor] | None = None, ) -> tuple[torch.Tensor, list[torch.Tensor]]: if len(features) != len(self.encoder_channels): raise ValueError( f"feature count mismatch: got {len(features)}, expected {len(self.encoder_channels)}" ) x = features[-1] skips = list(reversed(features[:-1])) decoder_features = [] if boundary_hints is None: boundary_hints = [None] * len(self.blocks) elif len(boundary_hints) != len(self.blocks): raise ValueError("boundary_hints length must match decoder depth.") for block, skip, boundary_hint in zip(self.blocks, skips, boundary_hints): x = block(x, skip, stability_map=stability_map, boundary_hint=boundary_hint) decoder_features.append(x) return x, decoder_features