boundary.py 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. from __future__ import annotations
  2. import torch
  3. import torch.nn.functional as F
  4. def _ensure_nchw(mask: torch.Tensor) -> torch.Tensor:
  5. if mask.ndim == 3:
  6. return mask.unsqueeze(1)
  7. if mask.ndim != 4:
  8. raise ValueError(f"Expected mask with 3 or 4 dims, got shape {tuple(mask.shape)}")
  9. return mask
  10. def mask_to_boundary_map(mask: torch.Tensor, dilation: int = 1) -> torch.Tensor:
  11. """
  12. 通过最大池化近似形态学梯度,生成边界图。
  13. """
  14. mask = _ensure_nchw(mask).float()
  15. kernel_size = dilation * 2 + 1
  16. pad = dilation
  17. dilated = F.max_pool2d(mask, kernel_size=kernel_size, stride=1, padding=pad)
  18. eroded = -F.max_pool2d(-mask, kernel_size=kernel_size, stride=1, padding=pad)
  19. boundary = (dilated - eroded).clamp_min(0.0)
  20. return (boundary > 0).float()
  21. def logits_to_binary_mask(logits: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
  22. if logits.shape[1] == 1:
  23. probs = torch.sigmoid(logits)
  24. return (probs >= threshold).float()
  25. preds = torch.argmax(logits, dim=1, keepdim=True)
  26. return preds.float()
  27. def logits_to_boundary(logits: torch.Tensor, threshold: float = 0.5, dilation: int = 1) -> torch.Tensor:
  28. mask = logits_to_binary_mask(logits, threshold=threshold)
  29. return mask_to_boundary_map(mask, dilation=dilation)
  30. def boundary_band_map(boundary: torch.Tensor, radius: int = 2) -> torch.Tensor:
  31. boundary = _ensure_nchw(boundary).float()
  32. kernel_size = radius * 2 + 1
  33. return F.max_pool2d(boundary, kernel_size=kernel_size, stride=1, padding=radius)
  34. __all__ = [
  35. "mask_to_boundary_map",
  36. "logits_to_binary_mask",
  37. "logits_to_boundary",
  38. "boundary_band_map",
  39. ]