| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107 |
- from __future__ import annotations
- from typing import Sequence
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from .layers_2d import Conv2dBN
class DecodeRefineBlock2d(nn.Module):
    """
    Lightweight residual refinement applied to decoded, fused features.

    The residual branch is ``Conv2dBN -> ReLU -> Conv2dBN`` with the channel
    count passed unchanged to both ``Conv2dBN`` layers; its output is added
    back onto the block input.

    Args:
        channels: Channel count handed to both ``Conv2dBN`` layers as their
            input and output width.
    """

    def __init__(self, channels: int) -> None:
        super().__init__()
        # NOTE(review): Conv2dBN comes from .layers_2d; the positional
        # arguments are presumably (in, out, kernel, stride, padding) --
        # confirm against its definition.
        branch_layers = [
            Conv2dBN(channels, channels, 3, 1, 1),
            nn.ReLU(inplace=True),
            Conv2dBN(channels, channels, 3, 1, 1),
        ]
        self.refine = nn.Sequential(*branch_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the output of the refinement branch on ``x``."""
        residual = self.refine(x)
        return x + residual
class SegmentationDecodeBlock2d(nn.Module):
    """
    Single decode stage: upsample the high-level feature, fuse it with the
    skip feature, then refine the result.

    Args:
        in_channels: Channel count given to the projection of the incoming
            (higher-level) feature.
        skip_channels: Channel count given to the projection of the skip
            feature.
        out_channels: Channel count produced by both projections, by the
            fusion stack and used by the refinement block.
    """

    def __init__(self, in_channels: int, skip_channels: int, out_channels: int) -> None:
        super().__init__()

        def conv_relu(src: int, dst: int, third: int, fifth: int) -> list[nn.Module]:
            # One Conv2dBN followed by an in-place ReLU.
            # NOTE(review): the third/fifth positional arguments of Conv2dBN
            # are presumably kernel size and padding (with stride fixed to 1
            # in between) -- confirm against .layers_2d.
            return [Conv2dBN(src, dst, third, 1, fifth), nn.ReLU(inplace=True)]

        # Both inputs are projected to the same width before concatenation.
        self.high_proj = nn.Sequential(*conv_relu(in_channels, out_channels, 1, 0))
        self.skip_proj = nn.Sequential(*conv_relu(skip_channels, out_channels, 1, 0))

        # The fusion stack consumes the concatenation of the two projections,
        # hence twice out_channels on its input side.
        self.fuse = nn.Sequential(
            *conv_relu(out_channels * 2, out_channels, 3, 1),
            *conv_relu(out_channels, out_channels, 3, 1),
        )
        self.refine = DecodeRefineBlock2d(out_channels)

    def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        """
        Fuse ``x`` with ``skip`` and return the refined result.

        ``x`` is bilinearly resized to the last two dimensions of ``skip``
        before projection; the two projected tensors are concatenated along
        dimension 1.
        """
        target_size = skip.shape[-2:]
        upsampled = F.interpolate(
            x, size=target_size, mode="bilinear", align_corners=False
        )
        high = self.high_proj(upsampled)
        low = self.skip_proj(skip)
        fused = self.fuse(torch.cat((high, low), dim=1))
        return self.refine(fused)
class SegmentationDecoder2d(nn.Module):
    """
    Plain multi-scale decoder skeleton.

    Input features are expected ordered from shallow to deep; the last
    feature is treated as the deepest input and every earlier one as a skip
    connection.

    Args:
        encoder_channels: Channel count of each encoder stage, shallow to
            deep. Must contain at least two entries.
        decoder_channels: Output channel count of each decode block, in the
            order the blocks run (deepest skip first). Defaults to the skip
            channel counts in reverse encoder order.

    Raises:
        ValueError: If fewer than two encoder stages are given, or if
            ``decoder_channels`` does not have ``len(encoder_channels) - 1``
            entries.
    """

    def __init__(
        self,
        encoder_channels: Sequence[int],
        decoder_channels: Sequence[int] | None = None,
    ) -> None:
        super().__init__()
        if len(encoder_channels) < 2:
            raise ValueError("encoder_channels must contain at least two stages.")
        self.encoder_channels = list(encoder_channels)

        # Everything but the deepest stage acts as a skip, consumed deepest first.
        *shallow_channels, deepest_channels = self.encoder_channels
        skip_plan = shallow_channels[::-1]

        if decoder_channels is None:
            decoder_channels = list(skip_plan)
        if len(decoder_channels) != len(self.encoder_channels) - 1:
            raise ValueError("decoder_channels length must match len(encoder_channels) - 1.")

        # Each block's input width is the previous block's output width; the
        # first block is fed by the deepest encoder stage.
        out_plan = list(decoder_channels)
        in_plan = [deepest_channels, *out_plan[:-1]]

        self.blocks = nn.ModuleList(
            [
                SegmentationDecodeBlock2d(stage_in, stage_skip, stage_out)
                for stage_in, stage_skip, stage_out in zip(in_plan, skip_plan, out_plan)
            ]
        )
        self.out_channels = out_plan[-1]

    def forward(self, features: Sequence[torch.Tensor]) -> tuple[torch.Tensor, list[torch.Tensor]]:
        """
        Decode a shallow-to-deep feature sequence.

        Returns:
            A tuple ``(x, decoder_features)`` where ``decoder_features`` holds
            the output of every block in execution order and ``x`` is the
            output of the last block.

        Raises:
            ValueError: If the number of features differs from the number of
                encoder stages given at construction.
        """
        expected = len(self.encoder_channels)
        if len(features) != expected:
            raise ValueError(
                f"feature count mismatch: got {len(features)}, expected {len(self.encoder_channels)}"
            )

        current = features[-1]
        stage_outputs = []
        # Walk the skips from the deepest one back to the shallowest.
        for stage, lateral in zip(self.blocks, reversed(features[:-1])):
            current = stage(current, lateral)
            stage_outputs.append(current)
        return current, stage_outputs
|