# decoder_2d.py
  1. from __future__ import annotations
  2. from typing import Sequence
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. from .layers_2d import Conv2dBN
  7. class DecodeRefineBlock2d(nn.Module):
  8. """
  9. 对解码后的融合特征做轻量残差细化。
  10. """
  11. def __init__(self, channels: int) -> None:
  12. super().__init__()
  13. self.refine = nn.Sequential(
  14. Conv2dBN(channels, channels, 3, 1, 1),
  15. nn.ReLU(inplace=True),
  16. Conv2dBN(channels, channels, 3, 1, 1),
  17. )
  18. def forward(self, x: torch.Tensor) -> torch.Tensor:
  19. return x + self.refine(x)
  20. class SegmentationDecodeBlock2d(nn.Module):
  21. """
  22. 单层解码块:上采样高层特征,与 skip 特征融合后细化。
  23. """
  24. def __init__(self, in_channels: int, skip_channels: int, out_channels: int) -> None:
  25. super().__init__()
  26. self.high_proj = nn.Sequential(
  27. Conv2dBN(in_channels, out_channels, 1, 1, 0),
  28. nn.ReLU(inplace=True),
  29. )
  30. self.skip_proj = nn.Sequential(
  31. Conv2dBN(skip_channels, out_channels, 1, 1, 0),
  32. nn.ReLU(inplace=True),
  33. )
  34. self.fuse = nn.Sequential(
  35. Conv2dBN(out_channels * 2, out_channels, 3, 1, 1),
  36. nn.ReLU(inplace=True),
  37. Conv2dBN(out_channels, out_channels, 3, 1, 1),
  38. nn.ReLU(inplace=True),
  39. )
  40. self.refine = DecodeRefineBlock2d(out_channels)
  41. def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
  42. x = F.interpolate(x, size=skip.shape[-2:], mode="bilinear", align_corners=False)
  43. x = self.high_proj(x)
  44. skip = self.skip_proj(skip)
  45. x = self.fuse(torch.cat([x, skip], dim=1))
  46. return self.refine(x)
  47. class SegmentationDecoder2d(nn.Module):
  48. """
  49. 纯净的多尺度解码器骨架。
  50. 输入特征默认按从浅到深排列,最后一个特征视为最深层输入。
  51. """
  52. def __init__(
  53. self,
  54. encoder_channels: Sequence[int],
  55. decoder_channels: Sequence[int] | None = None,
  56. ) -> None:
  57. super().__init__()
  58. if len(encoder_channels) < 2:
  59. raise ValueError("encoder_channels must contain at least two stages.")
  60. self.encoder_channels = list(encoder_channels)
  61. if decoder_channels is None:
  62. decoder_channels = list(reversed(self.encoder_channels[:-1]))
  63. if len(decoder_channels) != len(self.encoder_channels) - 1:
  64. raise ValueError("decoder_channels length must match len(encoder_channels) - 1.")
  65. in_channels = self.encoder_channels[-1]
  66. skip_channels = list(reversed(self.encoder_channels[:-1]))
  67. blocks = []
  68. for skip_ch, out_ch in zip(skip_channels, decoder_channels):
  69. blocks.append(SegmentationDecodeBlock2d(in_channels, skip_ch, out_ch))
  70. in_channels = out_ch
  71. self.blocks = nn.ModuleList(blocks)
  72. self.out_channels = in_channels
  73. def forward(self, features: Sequence[torch.Tensor]) -> tuple[torch.Tensor, list[torch.Tensor]]:
  74. if len(features) != len(self.encoder_channels):
  75. raise ValueError(
  76. f"feature count mismatch: got {len(features)}, expected {len(self.encoder_channels)}"
  77. )
  78. x = features[-1]
  79. skips = list(reversed(features[:-1]))
  80. decoder_features = []
  81. for block, skip in zip(self.blocks, skips):
  82. x = block(x, skip)
  83. decoder_features.append(x)
  84. return x, decoder_features