| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768 |
- from __future__ import annotations
- from argparse import Namespace
- from pathlib import Path
- from typing import Any, Sequence
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from .decoder_2d import SegmentationDecoder2d
- from .layers_2d import Conv2dBN
- from .swinv2_encoder_2d import SwinV2Encoder2d
class SegmentationHead2d(nn.Module):
    """Map decoder features to per-pixel class logits at a requested resolution.

    The head applies ``Conv2dBN(in_channels, in_channels, 3, 1, 1)`` followed by
    an in-place ReLU, then a 1x1 ``nn.Conv2d`` (with bias) that projects to
    ``out_channels``. The logits are then resized with bilinear interpolation
    (``align_corners=False``).

    NOTE(review): the positional ``3, 1, 1`` arguments to ``Conv2dBN`` are
    presumably kernel size / stride / padding; confirm in ``layers_2d``.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        # Build the layers in the same order as they are registered, so the
        # Sequential indices ("0", "1", "2") and parameter init order stay stable.
        refine = Conv2dBN(in_channels, in_channels, 3, 1, 1)
        activation = nn.ReLU(inplace=True)
        classifier = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=True)
        self.block = nn.Sequential(refine, activation, classifier)

    def forward(self, x: torch.Tensor, output_size: tuple[int, int]) -> torch.Tensor:
        """Compute logits from ``x`` and resample them to ``output_size``.

        Args:
            x: Feature map fed to the head (assumed ``(N, in_channels, h, w)``;
                TODO confirm against the decoder).
            output_size: Target spatial size passed to ``F.interpolate``.

        Returns:
            Logits with ``out_channels`` channels at spatial size ``output_size``.
        """
        logits = self.block(x)
        resized = F.interpolate(
            logits,
            size=output_size,
            mode="bilinear",
            align_corners=False,
        )
        return resized
class SegmentationModel2d(nn.Module):
    """Simplified ultrasound segmentation network.

    The network is an encoder-decoder. A ``SwinV2Encoder2d`` backbone produces
    multi-stage features. A ``SegmentationDecoder2d`` fuses them, and a
    ``SegmentationHead2d`` turns the fused map into class logits at the input's
    spatial resolution.

    Args:
        num_classes: Number of output channels of the segmentation head.
        model_name: Forwarded to ``SwinV2Encoder2d``.
        config_path: Forwarded to ``SwinV2Encoder2d``.
        weight_path: Forwarded to ``SwinV2Encoder2d``.
        args: Forwarded to ``SwinV2Encoder2d``.
        decoder_channels: Forwarded to ``SegmentationDecoder2d`` (keyword-only).
        load_weights: Forwarded to ``SwinV2Encoder2d`` (keyword-only).
        **encoder_kwargs: Extra keyword arguments forwarded to ``SwinV2Encoder2d``.
    """

    def __init__(
        self,
        num_classes: int,
        model_name: str | None = None,
        config_path: str | Path | None = None,
        weight_path: str | Path | None = None,
        args: Namespace | None = None,
        *,
        decoder_channels: Sequence[int] | None = None,
        load_weights: bool = True,
        **encoder_kwargs: Any,
    ) -> None:
        super().__init__()
        # Registration order (encoder -> decoder -> head) is preserved so that
        # state_dict key order and parameter initialisation order are unchanged.
        self.encoder = SwinV2Encoder2d(
            model_name=model_name,
            config_path=config_path,
            weight_path=weight_path,
            args=args,
            load_weights=load_weights,
            **encoder_kwargs,
        )
        # The decoder is sized from the channel counts the encoder reports per stage.
        self.decoder = SegmentationDecoder2d(
            encoder_channels=self.encoder.stage_channels,
            decoder_channels=decoder_channels,
        )
        head_in_channels = self.decoder.out_channels
        self.segmentation_head = SegmentationHead2d(head_in_channels, num_classes)

    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        """Run the full network on ``x``.

        Args:
            x: Input image batch (assumed ``(N, C, H, W)``; TODO confirm).

        Returns:
            A dict with a single key ``"seg_logits"`` whose spatial size equals
            ``x.shape[-2:]``.
        """
        encoder_outputs = self.encoder(x)
        # The decoder returns a pair; only the first element is used here.
        fused, _ignored = self.decoder(encoder_outputs["features"])
        target_size = x.shape[-2:]
        logits = self.segmentation_head(fused, output_size=target_size)
        return {"seg_logits": logits}
|