segmentation_2d.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. from __future__ import annotations
  2. from argparse import Namespace
  3. from pathlib import Path
  4. from typing import Any, Sequence
  5. import torch
  6. import torch.nn as nn
  7. import torch.nn.functional as F
  8. from .decoder_2d import SegmentationDecoder2d
  9. from .layers_2d import Conv2dBN
  10. from .swinv2_encoder_2d import SwinV2Encoder2d
  11. class SegmentationHead2d(nn.Module):
  12. def __init__(self, in_channels: int, out_channels: int) -> None:
  13. super().__init__()
  14. self.block = nn.Sequential(
  15. Conv2dBN(in_channels, in_channels, 3, 1, 1),
  16. nn.ReLU(inplace=True),
  17. nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=True),
  18. )
  19. def forward(self, x: torch.Tensor, output_size: tuple[int, int]) -> torch.Tensor:
  20. x = self.block(x)
  21. return F.interpolate(x, size=output_size, mode="bilinear", align_corners=False)
  22. class SegmentationModel2d(nn.Module):
  23. """
  24. 简化后的超声分割网络。
  25. """
  26. def __init__(
  27. self,
  28. num_classes: int,
  29. model_name: str | None = None,
  30. config_path: str | Path | None = None,
  31. weight_path: str | Path | None = None,
  32. args: Namespace | None = None,
  33. *,
  34. decoder_channels: Sequence[int] | None = None,
  35. load_weights: bool = True,
  36. **encoder_kwargs: Any,
  37. ) -> None:
  38. super().__init__()
  39. self.encoder = SwinV2Encoder2d(
  40. model_name=model_name,
  41. config_path=config_path,
  42. weight_path=weight_path,
  43. args=args,
  44. load_weights=load_weights,
  45. **encoder_kwargs,
  46. )
  47. self.decoder = SegmentationDecoder2d(
  48. encoder_channels=self.encoder.stage_channels,
  49. decoder_channels=decoder_channels,
  50. )
  51. self.segmentation_head = SegmentationHead2d(self.decoder.out_channels, num_classes)
  52. def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
  53. features = self.encoder(x)["features"]
  54. decoder_out, _ = self.decoder(features)
  55. output_size = x.shape[-2:]
  56. seg_logits = self.segmentation_head(decoder_out, output_size=output_size)
  57. return {"seg_logits": seg_logits}