"""Boundary-map helpers for segmentation masks and logits, built on max pooling."""

from __future__ import annotations

import torch
import torch.nn.functional as F


def _ensure_nchw(mask: torch.Tensor) -> torch.Tensor:
    """Return ``mask`` as a 4-D tensor.

    A 4-D tensor is returned unchanged. A 3-D tensor gets a singleton
    dimension inserted at index 1 (presumably turning ``(N, H, W)`` into
    ``(N, 1, H, W)``; only the rank is checked here, not the axis meaning).

    Args:
        mask: Tensor with 3 or 4 dimensions.

    Returns:
        A 4-D tensor.

    Raises:
        ValueError: If ``mask`` has neither 3 nor 4 dimensions.
    """
    rank = mask.ndim
    if rank == 4:
        return mask
    if rank == 3:
        return mask.unsqueeze(1)
    raise ValueError(f"Expected mask with 3 or 4 dims, got shape {tuple(mask.shape)}")


def mask_to_boundary_map(mask: torch.Tensor, dilation: int = 1) -> torch.Tensor:
    """Build a boundary map by approximating the morphological gradient with max pooling.

    The mask is dilated (sliding-window max) and eroded (sliding-window min)
    with a window of size ``2 * dilation + 1`` and stride 1; positions where
    the two results differ are marked as boundary.

    Args:
        mask: Tensor with 3 or 4 dimensions; it is converted to float.
        dilation: Half-width of the pooling window, also used as the padding.

    Returns:
        A 4-D float tensor holding 1.0 where dilation exceeds erosion and 0.0
        elsewhere.

    Raises:
        ValueError: If ``mask`` has neither 3 nor 4 dimensions.
    """
    mask4d = _ensure_nchw(mask).float()
    pool_args = {"kernel_size": 2 * dilation + 1, "stride": 1, "padding": dilation}

    grown = F.max_pool2d(mask4d, **pool_args)
    # Erosion is a min-pool, written as the negated max-pool of the negated mask.
    shrunk = -F.max_pool2d(-mask4d, **pool_args)

    gradient = torch.clamp(grown - shrunk, min=0.0)
    return (gradient > 0).float()


def logits_to_binary_mask(logits: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
    """Convert raw logits into a hard prediction map.

    When dimension 1 has size 1, the logits go through a sigmoid and are
    compared against ``threshold`` (``>=``), giving 0.0 / 1.0 values.
    Otherwise the argmax over dimension 1 is taken with ``keepdim=True``, so
    the values are the winning indices along that dimension (as floats) and
    ``threshold`` is not used.

    Args:
        logits: Tensor whose dimension 1 is treated as the channel/class axis.
        threshold: Probability cut-off for the single-channel case.

    Returns:
        A float tensor with dimension 1 of size 1.
    """
    single_channel = logits.shape[1] == 1
    if not single_channel:
        return torch.argmax(logits, dim=1, keepdim=True).float()
    return (torch.sigmoid(logits) >= threshold).float()


def logits_to_boundary(
    logits: torch.Tensor,
    threshold: float = 0.5,
    dilation: int = 1,
) -> torch.Tensor:
    """Turn logits into a boundary map.

    Chains :func:`logits_to_binary_mask` and :func:`mask_to_boundary_map`.

    Args:
        logits: Raw logits, forwarded to :func:`logits_to_binary_mask`.
        threshold: Forwarded to :func:`logits_to_binary_mask`.
        dilation: Forwarded to :func:`mask_to_boundary_map`.

    Returns:
        The float boundary map produced by :func:`mask_to_boundary_map`.
    """
    return mask_to_boundary_map(
        logits_to_binary_mask(logits, threshold=threshold),
        dilation=dilation,
    )


def boundary_band_map(boundary: torch.Tensor, radius: int = 2) -> torch.Tensor:
    """Widen a boundary map into a band using a sliding-window maximum.

    Args:
        boundary: Tensor with 3 or 4 dimensions; it is converted to float.
        radius: Half-width of the pooling window (window size is
            ``2 * radius + 1``, stride 1), also used as the padding.

    Returns:
        A 4-D float tensor containing the max-pooled boundary map.

    Raises:
        ValueError: If ``boundary`` has neither 3 nor 4 dimensions.
    """
    boundary4d = _ensure_nchw(boundary).float()
    window = 2 * radius + 1
    band = F.max_pool2d(boundary4d, kernel_size=window, stride=1, padding=radius)
    return band


__all__ = [
    "mask_to_boundary_map",
    "logits_to_binary_mask",
    "logits_to_boundary",
    "boundary_band_map",
]