|
@@ -421,41 +421,6 @@ class XNetEncoder2d(nn.Module):
|
|
|
return [e1, e2, e3, e4]
|
|
return [e1, e2, e3, e4]
|
|
|
|
|
|
|
|
|
|
|
|
|
-class XGuideProjector2d(nn.Module):
|
|
|
|
|
- # Guides are projected from encoder features and aligned to decoder resolution.
|
|
|
|
|
- def __init__(
|
|
|
|
|
- self, in_channels: int, out_channels: int, mode: str = "affine"
|
|
|
|
|
- ) -> None:
|
|
|
|
|
- super().__init__()
|
|
|
|
|
- self.mode = mode
|
|
|
|
|
- if mode == "affine":
|
|
|
|
|
- self.proj = nn.Sequential(
|
|
|
|
|
- Conv2dBN(in_channels, out_channels * 2, 1, 1, 0),
|
|
|
|
|
- nn.ReLU(inplace=True),
|
|
|
|
|
- nn.Conv2d(out_channels * 2, out_channels * 2, kernel_size=1, bias=True),
|
|
|
|
|
- )
|
|
|
|
|
- elif mode == "feature":
|
|
|
|
|
- self.proj = nn.Sequential(
|
|
|
|
|
- Conv2dBN(in_channels, out_channels, 1, 1, 0),
|
|
|
|
|
- nn.ReLU(inplace=True),
|
|
|
|
|
- )
|
|
|
|
|
- else:
|
|
|
|
|
- raise ValueError(f"Unsupported guide mode: {mode}")
|
|
|
|
|
-
|
|
|
|
|
- def forward(
|
|
|
|
|
- self,
|
|
|
|
|
- x: torch.Tensor,
|
|
|
|
|
- target_size: tuple[int, int],
|
|
|
|
|
- ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
|
|
|
|
- x = F.interpolate(x, size=target_size, mode="bilinear", align_corners=False)
|
|
|
|
|
- x = self.proj(x)
|
|
|
|
|
- if self.mode == "affine":
|
|
|
|
|
- gamma, beta = torch.chunk(x, 2, dim=1)
|
|
|
|
|
- gamma = torch.sigmoid(gamma) + 0.5
|
|
|
|
|
- return gamma, beta
|
|
|
|
|
- return x
|
|
|
|
|
-
|
|
|
|
|
-
|
|
|
|
|
class XSkipFusion2d(nn.Module):
|
|
class XSkipFusion2d(nn.Module):
|
|
|
# Decoder input and skip feature are aligned, projected, and fused together.
|
|
# Decoder input and skip feature are aligned, projected, and fused together.
|
|
|
def __init__(self, in_channels: int, skip_channels: int, out_channels: int) -> None:
|
|
def __init__(self, in_channels: int, skip_channels: int, out_channels: int) -> None:
|
|
@@ -480,27 +445,6 @@ class XSkipFusion2d(nn.Module):
|
|
|
return self.fuse(torch.cat([x, skip], dim=1))
|
|
return self.fuse(torch.cat([x, skip], dim=1))
|
|
|
|
|
|
|
|
|
|
|
|
|
-class XGuideModulation2d(nn.Module):
|
|
|
|
|
- # Apply either direct affine guide or feature-to-affine modulation.
|
|
|
|
|
- def __init__(self, channels: int, guide_mode: str = "affine") -> None:
|
|
|
|
|
- super().__init__()
|
|
|
|
|
- self.guide_mode = guide_mode
|
|
|
|
|
- if guide_mode == "feature":
|
|
|
|
|
- self.to_affine = nn.Conv2d(channels, channels * 2, kernel_size=1, bias=True)
|
|
|
|
|
-
|
|
|
|
|
- def forward(
|
|
|
|
|
- self,
|
|
|
|
|
- x: torch.Tensor,
|
|
|
|
|
- guide: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
|
|
|
|
- ) -> torch.Tensor:
|
|
|
|
|
- if self.guide_mode == "affine":
|
|
|
|
|
- gamma, beta = guide
|
|
|
|
|
- else:
|
|
|
|
|
- gamma, beta = torch.chunk(self.to_affine(guide), 2, dim=1)
|
|
|
|
|
- gamma = torch.sigmoid(gamma) + 0.5
|
|
|
|
|
- return gamma * x + beta
|
|
|
|
|
-
|
|
|
|
|
-
|
|
|
|
|
class XFrequencyRefine2d(nn.Module):
|
|
class XFrequencyRefine2d(nn.Module):
|
|
|
def __init__(
|
|
def __init__(
|
|
|
self,
|
|
self,
|
|
@@ -597,14 +541,12 @@ class XFrequencyRefine2d(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
|
class XCRB2d(nn.Module):
|
|
class XCRB2d(nn.Module):
|
|
|
- # Decoder block: skip fusion -> guide modulation -> frequency refine -> residual output.
|
|
|
|
|
|
|
+ # Decoder block: U-Net skip fusion -> frequency refine -> residual output.
|
|
|
def __init__(
|
|
def __init__(
|
|
|
self,
|
|
self,
|
|
|
in_channels: int,
|
|
in_channels: int,
|
|
|
skip_channels: int,
|
|
skip_channels: int,
|
|
|
- guide_channels: int,
|
|
|
|
|
out_channels: int,
|
|
out_channels: int,
|
|
|
- guide_mode: str = "affine",
|
|
|
|
|
use_frequency_refine: bool = True,
|
|
use_frequency_refine: bool = True,
|
|
|
low_freq_radius_h: float = 0.25,
|
|
low_freq_radius_h: float = 0.25,
|
|
|
low_freq_radius_w: float = 0.25,
|
|
low_freq_radius_w: float = 0.25,
|
|
@@ -612,7 +554,6 @@ class XCRB2d(nn.Module):
|
|
|
) -> None:
|
|
) -> None:
|
|
|
super().__init__()
|
|
super().__init__()
|
|
|
self.skip_fusion = XSkipFusion2d(in_channels, skip_channels, out_channels)
|
|
self.skip_fusion = XSkipFusion2d(in_channels, skip_channels, out_channels)
|
|
|
- self.guide_modulation = XGuideModulation2d(out_channels, guide_mode=guide_mode)
|
|
|
|
|
self.frequency_refine = (
|
|
self.frequency_refine = (
|
|
|
XFrequencyRefine2d(
|
|
XFrequencyRefine2d(
|
|
|
out_channels,
|
|
out_channels,
|
|
@@ -628,16 +569,13 @@ class XCRB2d(nn.Module):
|
|
|
nn.ReLU(inplace=True),
|
|
nn.ReLU(inplace=True),
|
|
|
Conv2dBN(out_channels, out_channels, 3, 1, 1, bn_weight_init=0.0),
|
|
Conv2dBN(out_channels, out_channels, 3, 1, 1, bn_weight_init=0.0),
|
|
|
)
|
|
)
|
|
|
- self.guide_channels = guide_channels
|
|
|
|
|
|
|
|
|
|
def forward(
|
|
def forward(
|
|
|
self,
|
|
self,
|
|
|
x: torch.Tensor,
|
|
x: torch.Tensor,
|
|
|
skip: torch.Tensor,
|
|
skip: torch.Tensor,
|
|
|
- guide: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
|
|
|
|
) -> torch.Tensor:
|
|
) -> torch.Tensor:
|
|
|
x = self.skip_fusion(x, skip)
|
|
x = self.skip_fusion(x, skip)
|
|
|
- x = self.guide_modulation(x, guide)
|
|
|
|
|
x = x + self.frequency_refine(x)
|
|
x = x + self.frequency_refine(x)
|
|
|
return x + self.out_refine(x)
|
|
return x + self.out_refine(x)
|
|
|
|
|
|
|
@@ -677,15 +615,11 @@ class XNetDecoder2d(nn.Module):
|
|
|
raise ValueError("XNetDecoder2d expects 3 decoder channels.")
|
|
raise ValueError("XNetDecoder2d expects 3 decoder channels.")
|
|
|
c1, c2, c3, c4 = encoder_channels
|
|
c1, c2, c3, c4 = encoder_channels
|
|
|
d4, d3, d2 = decoder_channels
|
|
d4, d3, d2 = decoder_channels
|
|
|
- self.guide4 = XGuideProjector2d(c4, d4, mode=guide_mode)
|
|
|
|
|
- self.guide3 = XGuideProjector2d(c3, d3, mode=guide_mode)
|
|
|
|
|
- self.guide2 = XGuideProjector2d(c2, d2, mode=guide_mode)
|
|
|
|
|
|
|
+ self.guide_mode = guide_mode
|
|
|
self.dec4 = XCRB2d(
|
|
self.dec4 = XCRB2d(
|
|
|
c4,
|
|
c4,
|
|
|
c3,
|
|
c3,
|
|
|
d4,
|
|
d4,
|
|
|
- d4,
|
|
|
|
|
- guide_mode=guide_mode,
|
|
|
|
|
use_frequency_refine=use_frequency_refine,
|
|
use_frequency_refine=use_frequency_refine,
|
|
|
low_freq_radius_h=low_freq_radius_h,
|
|
low_freq_radius_h=low_freq_radius_h,
|
|
|
low_freq_radius_w=low_freq_radius_w,
|
|
low_freq_radius_w=low_freq_radius_w,
|
|
@@ -695,8 +629,6 @@ class XNetDecoder2d(nn.Module):
|
|
|
d4,
|
|
d4,
|
|
|
c2,
|
|
c2,
|
|
|
d3,
|
|
d3,
|
|
|
- d3,
|
|
|
|
|
- guide_mode=guide_mode,
|
|
|
|
|
use_frequency_refine=use_frequency_refine,
|
|
use_frequency_refine=use_frequency_refine,
|
|
|
low_freq_radius_h=low_freq_radius_h,
|
|
low_freq_radius_h=low_freq_radius_h,
|
|
|
low_freq_radius_w=low_freq_radius_w,
|
|
low_freq_radius_w=low_freq_radius_w,
|
|
@@ -706,8 +638,6 @@ class XNetDecoder2d(nn.Module):
|
|
|
d3,
|
|
d3,
|
|
|
c1,
|
|
c1,
|
|
|
d2,
|
|
d2,
|
|
|
- d2,
|
|
|
|
|
- guide_mode=guide_mode,
|
|
|
|
|
use_frequency_refine=use_frequency_refine,
|
|
use_frequency_refine=use_frequency_refine,
|
|
|
low_freq_radius_h=low_freq_radius_h,
|
|
low_freq_radius_h=low_freq_radius_h,
|
|
|
low_freq_radius_w=low_freq_radius_w,
|
|
low_freq_radius_w=low_freq_radius_w,
|
|
@@ -719,20 +649,13 @@ class XNetDecoder2d(nn.Module):
|
|
|
def forward(
|
|
def forward(
|
|
|
self,
|
|
self,
|
|
|
features: Sequence[torch.Tensor],
|
|
features: Sequence[torch.Tensor],
|
|
|
- ) -> tuple[
|
|
|
|
|
- torch.Tensor,
|
|
|
|
|
- list[torch.Tensor],
|
|
|
|
|
- list[torch.Tensor | tuple[torch.Tensor, torch.Tensor]],
|
|
|
|
|
- ]:
|
|
|
|
|
|
|
+ ) -> tuple[torch.Tensor, list[torch.Tensor], list[torch.Tensor]]:
|
|
|
e1, e2, e3, e4 = features
|
|
e1, e2, e3, e4 = features
|
|
|
- g4 = self.guide4(e4, target_size=e3.shape[-2:])
|
|
|
|
|
- d4 = self.dec4(e4, e3, g4)
|
|
|
|
|
- g3 = self.guide3(e3, target_size=e2.shape[-2:])
|
|
|
|
|
- d3 = self.dec3(d4, e2, g3)
|
|
|
|
|
- g2 = self.guide2(e2, target_size=e1.shape[-2:])
|
|
|
|
|
- d2 = self.dec2(d3, e1, g2)
|
|
|
|
|
|
|
+ d4 = self.dec4(e4, e3)
|
|
|
|
|
+ d3 = self.dec3(d4, e2)
|
|
|
|
|
+ d2 = self.dec2(d3, e1)
|
|
|
d1 = self.head_refine(d2)
|
|
d1 = self.head_refine(d2)
|
|
|
- return d1, [d4, d3, d2, d1], [g4, g3, g2]
|
|
|
|
|
|
|
+ return d1, [d4, d3, d2, d1], []
|
|
|
|
|
|
|
|
|
|
|
|
|
class XNetSegHead2d(nn.Module):
|
|
class XNetSegHead2d(nn.Module):
|
|
@@ -824,18 +747,13 @@ class XNet2d(nn.Module):
|
|
|
|
|
|
|
|
def forward(
|
|
def forward(
|
|
|
self, x: torch.Tensor
|
|
self, x: torch.Tensor
|
|
|
- ) -> dict[
|
|
|
|
|
- str, torch.Tensor | list[torch.Tensor] | list[tuple[torch.Tensor, torch.Tensor]]
|
|
|
|
|
- ]:
|
|
|
|
|
|
|
+ ) -> dict[str, torch.Tensor | list[torch.Tensor]]:
|
|
|
encoder_features = self.encoder(x)
|
|
encoder_features = self.encoder(x)
|
|
|
encoder_features[-1] = self.bottleneck(encoder_features[-1])
|
|
encoder_features[-1] = self.bottleneck(encoder_features[-1])
|
|
|
decoder_out, decoder_features, guides = self.decoder(encoder_features)
|
|
decoder_out, decoder_features, guides = self.decoder(encoder_features)
|
|
|
output_size = x.shape[-2:]
|
|
output_size = x.shape[-2:]
|
|
|
logits = self.segmentation_head(decoder_out, output_size=output_size)
|
|
logits = self.segmentation_head(decoder_out, output_size=output_size)
|
|
|
- outputs: dict[
|
|
|
|
|
- str,
|
|
|
|
|
- torch.Tensor | list[torch.Tensor] | list[tuple[torch.Tensor, torch.Tensor]],
|
|
|
|
|
- ] = {
|
|
|
|
|
|
|
+ outputs: dict[str, torch.Tensor | list[torch.Tensor]] = {
|
|
|
"logits": logits,
|
|
"logits": logits,
|
|
|
"seg_logits": logits,
|
|
"seg_logits": logits,
|
|
|
"encoder_features": encoder_features,
|
|
"encoder_features": encoder_features,
|