| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139 |
- from __future__ import annotations
- from typing import Sequence
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from .layers_2d import Conv2dBN
class BoundaryRefineBlock2d(nn.Module):
    """Lightly refine decoder features using a boundary hint and a stability map.

    The optional guidance maps are aligned to the feature map (dtype and spatial
    size) and summed onto a base gain of 1. The features are scaled by
    ``1 + stability_map + boundary_hint`` and then passed through a two-layer
    3x3 conv refinement. The result is added back to the input as a residual.
    """

    def __init__(self, channels: int) -> None:
        """Build the refinement branch.

        Args:
            channels: Number of feature channels. The input and output channel
                counts are equal, so the residual sum is shape-compatible.
        """
        super().__init__()
        self.refine = nn.Sequential(
            Conv2dBN(channels, channels, 3, 1, 1),
            nn.ReLU(inplace=True),
            Conv2dBN(channels, channels, 3, 1, 1),
        )

    @staticmethod
    def _align_to(guide: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Return ``guide`` cast to ``x``'s dtype and resized to ``x``'s spatial size.

        Args:
            guide: Guidance map, either ``(B, C', H', W')`` or channel-less
                ``(B, H', W')`` when ``x`` is 4-D.
            x: Feature map whose dtype and trailing two dimensions are the target.

        Returns:
            The aligned guidance map, ready to broadcast against ``x``.
        """
        # Bilinear interpolation is not implemented for bool/integer tensors.
        # Also, a float32 map multiplied into half-precision features would
        # promote the conv input to float32. Matching x's dtype avoids both.
        guide = guide.to(dtype=x.dtype)
        # Allow a channel-less (B, H, W) map by giving it a singleton channel axis.
        if guide.dim() == 3 and x.dim() == 4:
            guide = guide.unsqueeze(1)
        # Bilinear resizing to the same size with align_corners=False is an
        # identity, so it is only done when the resolution actually differs.
        if guide.shape[-2:] != x.shape[-2:]:
            guide = F.interpolate(
                guide, size=x.shape[-2:], mode="bilinear", align_corners=False
            )
        return guide

    def forward(
        self,
        x: torch.Tensor,
        boundary_hint: torch.Tensor | None = None,
        stability_map: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Refine ``x`` under optional boundary and stability guidance.

        Args:
            x: Decoder feature map.
            boundary_hint: Optional map added to the modulation gain. It is
                resized to ``x``'s spatial size.
            stability_map: Optional map added to the modulation gain. It is
                resized to ``x``'s spatial size.

        Returns:
            ``x + refine(x * (1 + stability_map + boundary_hint))``, where
            missing maps contribute nothing.
        """
        modulator: torch.Tensor | float = 1.0
        # Same accumulation order as before: stability first, then boundary.
        if stability_map is not None:
            modulator = modulator + self._align_to(stability_map, x)
        if boundary_hint is not None:
            modulator = modulator + self._align_to(boundary_hint, x)
        return x + self.refine(x * modulator)
class StructureAwareDecodeBlock2d(nn.Module):
    """One structure-aware decoder stage.

    The stage works in four steps:

    1. Upsample the deeper feature map to the skip connection's resolution.
    2. Project both inputs to ``out_channels`` with a 1x1 conv + ReLU.
    3. Fuse their concatenation with two 3x3 conv + ReLU layers.
    4. Pass the result through a ``BoundaryRefineBlock2d`` driven by the
       optional stability map and boundary hint.
    """

    @staticmethod
    def _pointwise(cin: int, cout: int) -> nn.Sequential:
        # 1x1 projection followed by an in-place ReLU.
        return nn.Sequential(
            Conv2dBN(cin, cout, 1, 1, 0),
            nn.ReLU(inplace=True),
        )

    def __init__(self, in_channels: int, skip_channels: int, out_channels: int) -> None:
        """Create the projections, fusion layers and refinement block.

        Args:
            in_channels: Channels of the deeper (coarser) input.
            skip_channels: Channels of the skip-connection input.
            out_channels: Channels produced by this stage.
        """
        super().__init__()
        # Submodules are created in the same order and under the same names as
        # before, which keeps state_dict keys and parameter init unchanged.
        self.high_proj = self._pointwise(in_channels, out_channels)
        self.skip_proj = self._pointwise(skip_channels, out_channels)
        self.fuse = nn.Sequential(
            Conv2dBN(2 * out_channels, out_channels, 3, 1, 1),
            nn.ReLU(inplace=True),
            Conv2dBN(out_channels, out_channels, 3, 1, 1),
            nn.ReLU(inplace=True),
        )
        self.refine = BoundaryRefineBlock2d(out_channels)

    def forward(
        self,
        x: torch.Tensor,
        skip: torch.Tensor,
        stability_map: torch.Tensor | None = None,
        boundary_hint: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Decode one stage.

        Args:
            x: Deeper feature map; it is bilinearly resized to ``skip``'s size.
            skip: Skip-connection feature map that sets the output resolution.
            stability_map: Optional guidance forwarded to the refinement block.
            boundary_hint: Optional guidance forwarded to the refinement block.

        Returns:
            The refined fused feature map with ``out_channels`` channels.
        """
        upsampled = F.interpolate(
            x, size=skip.shape[-2:], mode="bilinear", align_corners=False
        )
        deep = self.high_proj(upsampled)
        lateral = self.skip_proj(skip)
        fused = self.fuse(torch.cat([deep, lateral], dim=1))
        return self.refine(
            fused, boundary_hint=boundary_hint, stability_map=stability_map
        )
class StructureAwareDecoder2d(nn.Module):
    """First-version skeleton of a structure-aware decoder.

    Input features are ordered shallow-to-deep; the last feature is treated as
    the deepest input. Decoding starts from that deepest feature and consumes the
    remaining features as skip connections from deepest to shallowest, one
    ``StructureAwareDecodeBlock2d`` per skip.
    """

    def __init__(self, encoder_channels: Sequence[int], decoder_channels: Sequence[int] | None = None) -> None:
        """Create one decode block per skip connection.

        Args:
            encoder_channels: Channel count of each encoder stage, shallow to
                deep. At least two stages are required.
            decoder_channels: Output channels of each decode block, deepest
                stage first. Defaults to the skip channels in that same order.

        Raises:
            ValueError: If fewer than two encoder stages are given, or if
                ``decoder_channels`` does not have ``len(encoder_channels) - 1``
                entries.
        """
        super().__init__()
        if len(encoder_channels) < 2:
            raise ValueError("encoder_channels must contain at least two stages.")
        self.encoder_channels = list(encoder_channels)
        # Skips are consumed deepest-first: every stage except the last, reversed.
        skip_channels = self.encoder_channels[-2::-1]
        out_chs = list(skip_channels) if decoder_channels is None else list(decoder_channels)
        if len(out_chs) != len(self.encoder_channels) - 1:
            raise ValueError("decoder_channels length must match len(encoder_channels) - 1.")
        # Each block takes the previous block's output; the first one takes the
        # deepest encoder feature.
        in_chs = [self.encoder_channels[-1], *out_chs[:-1]]
        self.blocks = nn.ModuleList(
            StructureAwareDecodeBlock2d(c_in, c_skip, c_out)
            for c_in, c_skip, c_out in zip(in_chs, skip_channels, out_chs)
        )
        self.out_channels = out_chs[-1]

    def forward(
        self,
        features: Sequence[torch.Tensor],
        stability_map: torch.Tensor | None = None,
        boundary_hints: Sequence[torch.Tensor] | None = None,
    ) -> tuple[torch.Tensor, list[torch.Tensor]]:
        """Decode an encoder feature pyramid.

        Args:
            features: Encoder features, shallow to deep; exactly one per
                ``encoder_channels`` entry.
            stability_map: Optional map passed unchanged to every decode block.
            boundary_hints: Optional per-block hints aligned with ``self.blocks``
                (deepest stage first). Entries may be ``None``.

        Returns:
            A tuple of the final decoded feature and the list of every block's
            output, deepest stage first. The final feature is the last entry.

        Raises:
            ValueError: If the number of features or boundary hints is wrong.
        """
        n_expected = len(self.encoder_channels)
        if len(features) != n_expected:
            raise ValueError(
                f"feature count mismatch: got {len(features)}, expected {n_expected}"
            )
        depth = len(self.blocks)
        if boundary_hints is not None and len(boundary_hints) != depth:
            raise ValueError("boundary_hints length must match decoder depth.")
        hints = [None] * depth if boundary_hints is None else boundary_hints

        x = features[-1]
        stage_outputs = []
        for stage, block in enumerate(self.blocks):
            # Stage 0 uses features[-2] (the deepest skip), stage 1 uses features[-3], ...
            x = block(
                x,
                features[-2 - stage],
                stability_map=stability_map,
                boundary_hint=hints[stage],
            )
            stage_outputs.append(x)
        return x, stage_outputs
|