| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152 |
- from __future__ import annotations
- import torch
- import torch.nn.functional as F
- def _ensure_nchw(mask: torch.Tensor) -> torch.Tensor:
- if mask.ndim == 3:
- return mask.unsqueeze(1)
- if mask.ndim != 4:
- raise ValueError(f"Expected mask with 3 or 4 dims, got shape {tuple(mask.shape)}")
- return mask
def mask_to_boundary_map(mask: torch.Tensor, dilation: int = 1) -> torch.Tensor:
    """Build a 0/1 boundary map by approximating the morphological gradient
    with max pooling.

    The mask is brought to 4-D by ``_ensure_nchw`` and cast to float. It is
    then dilated (max-pool) and eroded (negated max-pool of the negated
    mask) using a square window of side ``2 * dilation + 1``, stride 1 and
    padding ``dilation``. Positions where the dilated value exceeds the
    eroded value are marked as boundary.

    Args:
        mask: Mask tensor with 3 or 4 dimensions.
        dilation: Half-width of the pooling window.

    Returns:
        Float tensor containing 1.0 at boundary positions and 0.0 elsewhere.

    Raises:
        ValueError: Propagated from ``_ensure_nchw`` when ``mask`` is not
            3-D or 4-D.
    """
    x = _ensure_nchw(mask).float()
    window = 2 * dilation + 1
    pool_kwargs = {"kernel_size": window, "stride": 1, "padding": dilation}

    grown = F.max_pool2d(x, **pool_kwargs)
    # Min-pooling (erosion) expressed as a max-pool over the negated input.
    shrunk = F.max_pool2d(x.neg(), **pool_kwargs).neg()

    gradient = torch.clamp(grown - shrunk, min=0.0)
    return gradient.gt(0).float()
def logits_to_binary_mask(logits: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
    """Convert raw logits into a float prediction mask.

    The size of dimension 1 selects the decoding rule:

    * size 1: apply a sigmoid and mark entries whose probability is
      ``>= threshold`` with 1.0 (0.0 otherwise);
    * any other size: take the argmax over dimension 1 (kept as a singleton
      axis) and return the winning index as a float. With more than two
      channels the result therefore holds class indices, not only 0/1.

    Args:
        logits: Tensor with at least two dimensions; dimension 1 is read as
            the channel/class axis.
        threshold: Probability cut-off used only on the single-channel path.

    Returns:
        Float tensor with a size-1 dimension 1.
    """
    is_single_channel = logits.shape[1] == 1
    if not is_single_channel:
        class_index = torch.argmax(logits, dim=1, keepdim=True)
        return class_index.float()
    return (torch.sigmoid(logits) >= threshold).float()
def logits_to_boundary(logits: torch.Tensor, threshold: float = 0.5, dilation: int = 1) -> torch.Tensor:
    """Turn logits directly into a boundary map.

    Decodes ``logits`` with ``logits_to_binary_mask`` and feeds the result
    to ``mask_to_boundary_map``.

    Args:
        logits: Raw model outputs, forwarded to ``logits_to_binary_mask``.
        threshold: Forwarded to ``logits_to_binary_mask``.
        dilation: Forwarded to ``mask_to_boundary_map``.

    Returns:
        The tensor produced by ``mask_to_boundary_map``.
    """
    return mask_to_boundary_map(
        logits_to_binary_mask(logits, threshold=threshold),
        dilation=dilation,
    )
def boundary_band_map(boundary: torch.Tensor, radius: int = 2) -> torch.Tensor:
    """Thicken a boundary map into a band using max pooling.

    The input is brought to 4-D by ``_ensure_nchw``, cast to float, and
    max-pooled with a square window of side ``2 * radius + 1``, stride 1 and
    padding ``radius``.

    Args:
        boundary: Boundary tensor with 3 or 4 dimensions.
        radius: Half-width of the pooling window.

    Returns:
        The max-pooled float tensor.

    Raises:
        ValueError: Propagated from ``_ensure_nchw`` when ``boundary`` is
            not 3-D or 4-D.
    """
    edges = _ensure_nchw(boundary).float()
    return F.max_pool2d(edges, kernel_size=2 * radius + 1, stride=1, padding=radius)
# Names exported by ``from <module> import *``; the underscore-prefixed
# helper ``_ensure_nchw`` is intentionally left out.
__all__ = [
    "mask_to_boundary_map",
    "logits_to_binary_mask",
    "logits_to_boundary",
    "boundary_band_map",
]
|