# decoder_2d.py
  1. from __future__ import annotations
  2. from typing import Sequence
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. from .layers_2d import Conv2dBN
  7. class BoundaryRefineBlock2d(nn.Module):
  8. """
  9. 使用边界提示和稳定性图对解码特征做轻量细化。
  10. """
  11. def __init__(self, channels: int) -> None:
  12. super().__init__()
  13. self.refine = nn.Sequential(
  14. Conv2dBN(channels, channels, 3, 1, 1),
  15. nn.ReLU(inplace=True),
  16. Conv2dBN(channels, channels, 3, 1, 1),
  17. )
  18. def forward(
  19. self,
  20. x: torch.Tensor,
  21. boundary_hint: torch.Tensor | None = None,
  22. stability_map: torch.Tensor | None = None,
  23. ) -> torch.Tensor:
  24. modulator = 1.0
  25. if stability_map is not None:
  26. stability_map = F.interpolate(
  27. stability_map, size=x.shape[-2:], mode="bilinear", align_corners=False
  28. )
  29. modulator = modulator + stability_map
  30. if boundary_hint is not None:
  31. boundary_hint = F.interpolate(
  32. boundary_hint, size=x.shape[-2:], mode="bilinear", align_corners=False
  33. )
  34. modulator = modulator + boundary_hint
  35. return x + self.refine(x * modulator)
  36. class StructureAwareDecodeBlock2d(nn.Module):
  37. """
  38. 单层结构感知解码块。
  39. """
  40. def __init__(self, in_channels: int, skip_channels: int, out_channels: int) -> None:
  41. super().__init__()
  42. self.high_proj = nn.Sequential(
  43. Conv2dBN(in_channels, out_channels, 1, 1, 0),
  44. nn.ReLU(inplace=True),
  45. )
  46. self.skip_proj = nn.Sequential(
  47. Conv2dBN(skip_channels, out_channels, 1, 1, 0),
  48. nn.ReLU(inplace=True),
  49. )
  50. self.fuse = nn.Sequential(
  51. Conv2dBN(out_channels * 2, out_channels, 3, 1, 1),
  52. nn.ReLU(inplace=True),
  53. Conv2dBN(out_channels, out_channels, 3, 1, 1),
  54. nn.ReLU(inplace=True),
  55. )
  56. self.refine = BoundaryRefineBlock2d(out_channels)
  57. def forward(
  58. self,
  59. x: torch.Tensor,
  60. skip: torch.Tensor,
  61. stability_map: torch.Tensor | None = None,
  62. boundary_hint: torch.Tensor | None = None,
  63. ) -> torch.Tensor:
  64. x = F.interpolate(x, size=skip.shape[-2:], mode="bilinear", align_corners=False)
  65. x = self.high_proj(x)
  66. skip = self.skip_proj(skip)
  67. x = self.fuse(torch.cat([x, skip], dim=1))
  68. x = self.refine(x, boundary_hint=boundary_hint, stability_map=stability_map)
  69. return x
  70. class StructureAwareDecoder2d(nn.Module):
  71. """
  72. 第一版结构感知解码器骨架。
  73. 输入特征默认按从浅到深排列,最后一个特征视为最深层输入。
  74. """
  75. def __init__(self, encoder_channels: Sequence[int], decoder_channels: Sequence[int] | None = None) -> None:
  76. super().__init__()
  77. if len(encoder_channels) < 2:
  78. raise ValueError("encoder_channels must contain at least two stages.")
  79. self.encoder_channels = list(encoder_channels)
  80. if decoder_channels is None:
  81. decoder_channels = list(reversed(self.encoder_channels[:-1]))
  82. if len(decoder_channels) != len(self.encoder_channels) - 1:
  83. raise ValueError("decoder_channels length must match len(encoder_channels) - 1.")
  84. in_channels = self.encoder_channels[-1]
  85. skip_channels = list(reversed(self.encoder_channels[:-1]))
  86. blocks = []
  87. for skip_ch, out_ch in zip(skip_channels, decoder_channels):
  88. blocks.append(StructureAwareDecodeBlock2d(in_channels, skip_ch, out_ch))
  89. in_channels = out_ch
  90. self.blocks = nn.ModuleList(blocks)
  91. self.out_channels = in_channels
  92. def forward(
  93. self,
  94. features: Sequence[torch.Tensor],
  95. stability_map: torch.Tensor | None = None,
  96. boundary_hints: Sequence[torch.Tensor] | None = None,
  97. ) -> tuple[torch.Tensor, list[torch.Tensor]]:
  98. if len(features) != len(self.encoder_channels):
  99. raise ValueError(
  100. f"feature count mismatch: got {len(features)}, expected {len(self.encoder_channels)}"
  101. )
  102. x = features[-1]
  103. skips = list(reversed(features[:-1]))
  104. decoder_features = []
  105. if boundary_hints is None:
  106. boundary_hints = [None] * len(self.blocks)
  107. elif len(boundary_hints) != len(self.blocks):
  108. raise ValueError("boundary_hints length must match decoder depth.")
  109. for block, skip, boundary_hint in zip(self.blocks, skips, boundary_hints):
  110. x = block(x, skip, stability_map=stability_map, boundary_hint=boundary_hint)
  111. decoder_features.append(x)
  112. return x, decoder_features