浏览代码

refactor(xnet_2d): 移除引导投影器和调制模块,简化解码器结构

移除了 XGuideProjector2d 和 XGuideModulation2d 模块,
简化了 XCRB2d 解码器块的设计,去除了引导机制,
使解码器仅使用标准的U-Net跳跃连接融合方式。
更新了相关的导入语句和导出列表,移除了未使用的组件引用。
kekezack 1 周之前
父节点
当前提交
3efd259ee3
共有 3 个文件被更改,包括 30 次插入93 次删除
  1. 0 2
      lib/modules/__init__.py
  2. 9 91
      lib/modules/xnet_2d.py
  3. 21 0
      tests/test_xnet_2d.py

+ 0 - 2
lib/modules/__init__.py

@@ -10,7 +10,6 @@ from .layers_2d import (
 )
 from .xnet_2d import (
     XCRB2d,
-    XGuideProjector2d,
     XNet2d,
     XNetDecoder2d,
     XNetDownsample2d,
@@ -32,7 +31,6 @@ __all__ = [
     "Residual",
     "Scale",
     "XCRB2d",
-    "XGuideProjector2d",
     "XNet2d",
     "XNetDecoder2d",
     "XNetDownsample2d",

+ 9 - 91
lib/modules/xnet_2d.py

@@ -421,41 +421,6 @@ class XNetEncoder2d(nn.Module):
         return [e1, e2, e3, e4]
 
 
-class XGuideProjector2d(nn.Module):
-    # Guides are projected from encoder features and aligned to decoder resolution.
-    def __init__(
-        self, in_channels: int, out_channels: int, mode: str = "affine"
-    ) -> None:
-        super().__init__()
-        self.mode = mode
-        if mode == "affine":
-            self.proj = nn.Sequential(
-                Conv2dBN(in_channels, out_channels * 2, 1, 1, 0),
-                nn.ReLU(inplace=True),
-                nn.Conv2d(out_channels * 2, out_channels * 2, kernel_size=1, bias=True),
-            )
-        elif mode == "feature":
-            self.proj = nn.Sequential(
-                Conv2dBN(in_channels, out_channels, 1, 1, 0),
-                nn.ReLU(inplace=True),
-            )
-        else:
-            raise ValueError(f"Unsupported guide mode: {mode}")
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        target_size: tuple[int, int],
-    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        x = F.interpolate(x, size=target_size, mode="bilinear", align_corners=False)
-        x = self.proj(x)
-        if self.mode == "affine":
-            gamma, beta = torch.chunk(x, 2, dim=1)
-            gamma = torch.sigmoid(gamma) + 0.5
-            return gamma, beta
-        return x
-
-
 class XSkipFusion2d(nn.Module):
     # Decoder input and skip feature are aligned, projected, and fused together.
     def __init__(self, in_channels: int, skip_channels: int, out_channels: int) -> None:
@@ -480,27 +445,6 @@ class XSkipFusion2d(nn.Module):
         return self.fuse(torch.cat([x, skip], dim=1))
 
 
-class XGuideModulation2d(nn.Module):
-    # Apply either direct affine guide or feature-to-affine modulation.
-    def __init__(self, channels: int, guide_mode: str = "affine") -> None:
-        super().__init__()
-        self.guide_mode = guide_mode
-        if guide_mode == "feature":
-            self.to_affine = nn.Conv2d(channels, channels * 2, kernel_size=1, bias=True)
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        guide: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
-    ) -> torch.Tensor:
-        if self.guide_mode == "affine":
-            gamma, beta = guide
-        else:
-            gamma, beta = torch.chunk(self.to_affine(guide), 2, dim=1)
-            gamma = torch.sigmoid(gamma) + 0.5
-        return gamma * x + beta
-
-
 class XFrequencyRefine2d(nn.Module):
     def __init__(
         self,
@@ -597,14 +541,12 @@ class XFrequencyRefine2d(nn.Module):
 
 
 class XCRB2d(nn.Module):
-    # Decoder block: skip fusion -> guide modulation -> frequency refine -> residual output.
+    # Decoder block: U-Net skip fusion -> frequency refine -> residual output.
     def __init__(
         self,
         in_channels: int,
         skip_channels: int,
-        guide_channels: int,
         out_channels: int,
-        guide_mode: str = "affine",
         use_frequency_refine: bool = True,
         low_freq_radius_h: float = 0.25,
         low_freq_radius_w: float = 0.25,
@@ -612,7 +554,6 @@ class XCRB2d(nn.Module):
     ) -> None:
         super().__init__()
         self.skip_fusion = XSkipFusion2d(in_channels, skip_channels, out_channels)
-        self.guide_modulation = XGuideModulation2d(out_channels, guide_mode=guide_mode)
         self.frequency_refine = (
             XFrequencyRefine2d(
                 out_channels,
@@ -628,16 +569,13 @@ class XCRB2d(nn.Module):
             nn.ReLU(inplace=True),
             Conv2dBN(out_channels, out_channels, 3, 1, 1, bn_weight_init=0.0),
         )
-        self.guide_channels = guide_channels
 
     def forward(
         self,
         x: torch.Tensor,
         skip: torch.Tensor,
-        guide: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
     ) -> torch.Tensor:
         x = self.skip_fusion(x, skip)
-        x = self.guide_modulation(x, guide)
         x = x + self.frequency_refine(x)
         return x + self.out_refine(x)
 
@@ -677,15 +615,11 @@ class XNetDecoder2d(nn.Module):
             raise ValueError("XNetDecoder2d expects 3 decoder channels.")
         c1, c2, c3, c4 = encoder_channels
         d4, d3, d2 = decoder_channels
-        self.guide4 = XGuideProjector2d(c4, d4, mode=guide_mode)
-        self.guide3 = XGuideProjector2d(c3, d3, mode=guide_mode)
-        self.guide2 = XGuideProjector2d(c2, d2, mode=guide_mode)
+        self.guide_mode = guide_mode
         self.dec4 = XCRB2d(
             c4,
             c3,
             d4,
-            d4,
-            guide_mode=guide_mode,
             use_frequency_refine=use_frequency_refine,
             low_freq_radius_h=low_freq_radius_h,
             low_freq_radius_w=low_freq_radius_w,
@@ -695,8 +629,6 @@ class XNetDecoder2d(nn.Module):
             d4,
             c2,
             d3,
-            d3,
-            guide_mode=guide_mode,
             use_frequency_refine=use_frequency_refine,
             low_freq_radius_h=low_freq_radius_h,
             low_freq_radius_w=low_freq_radius_w,
@@ -706,8 +638,6 @@ class XNetDecoder2d(nn.Module):
             d3,
             c1,
             d2,
-            d2,
-            guide_mode=guide_mode,
             use_frequency_refine=use_frequency_refine,
             low_freq_radius_h=low_freq_radius_h,
             low_freq_radius_w=low_freq_radius_w,
@@ -719,20 +649,13 @@ class XNetDecoder2d(nn.Module):
     def forward(
         self,
         features: Sequence[torch.Tensor],
-    ) -> tuple[
-        torch.Tensor,
-        list[torch.Tensor],
-        list[torch.Tensor | tuple[torch.Tensor, torch.Tensor]],
-    ]:
+    ) -> tuple[torch.Tensor, list[torch.Tensor], list[torch.Tensor]]:
         e1, e2, e3, e4 = features
-        g4 = self.guide4(e4, target_size=e3.shape[-2:])
-        d4 = self.dec4(e4, e3, g4)
-        g3 = self.guide3(e3, target_size=e2.shape[-2:])
-        d3 = self.dec3(d4, e2, g3)
-        g2 = self.guide2(e2, target_size=e1.shape[-2:])
-        d2 = self.dec2(d3, e1, g2)
+        d4 = self.dec4(e4, e3)
+        d3 = self.dec3(d4, e2)
+        d2 = self.dec2(d3, e1)
         d1 = self.head_refine(d2)
-        return d1, [d4, d3, d2, d1], [g4, g3, g2]
+        return d1, [d4, d3, d2, d1], []
 
 
 class XNetSegHead2d(nn.Module):
@@ -824,18 +747,13 @@ class XNet2d(nn.Module):
 
     def forward(
         self, x: torch.Tensor
-    ) -> dict[
-        str, torch.Tensor | list[torch.Tensor] | list[tuple[torch.Tensor, torch.Tensor]]
-    ]:
+    ) -> dict[str, torch.Tensor | list[torch.Tensor]]:
         encoder_features = self.encoder(x)
         encoder_features[-1] = self.bottleneck(encoder_features[-1])
         decoder_out, decoder_features, guides = self.decoder(encoder_features)
         output_size = x.shape[-2:]
         logits = self.segmentation_head(decoder_out, output_size=output_size)
-        outputs: dict[
-            str,
-            torch.Tensor | list[torch.Tensor] | list[tuple[torch.Tensor, torch.Tensor]],
-        ] = {
+        outputs: dict[str, torch.Tensor | list[torch.Tensor]] = {
             "logits": logits,
             "seg_logits": logits,
             "encoder_features": encoder_features,

+ 21 - 0
tests/test_xnet_2d.py

@@ -60,3 +60,24 @@ def test_xnet2d_forward_preserves_segmentation_shape() -> None:
 
     assert outputs["seg_logits"].shape == (2, 1, 64, 64)
     assert outputs["logits"].shape == outputs["seg_logits"].shape
+
+
+def test_xnet2d_decoder_uses_plain_unet_skip_connections() -> None:
+    from lib.modules.xnet_2d import XNet2d
+
+    model = XNet2d(
+        in_channels=3,
+        num_classes=1,
+        encoder_channels=(8, 16, 24, 32),
+        encoder_depths=(1, 1, 1, 1),
+        decoder_channels=(24, 16, 8),
+        stem_channels=8,
+        bottleneck_depth=1,
+        use_global_branch_stage1=False,
+        ssm_d_state=1,
+        ssm_backend="torch",
+    )
+
+    decoder_module_names = dict(model.decoder.named_modules())
+
+    assert not any(name.startswith("guide") for name in decoder_module_names)