| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273 |
- from __future__ import annotations
- from typing import Any
- import torch
def _rand_uniform(low: float, high: float) -> float:
    """Draw one value uniformly between ``low`` and ``high`` using torch's RNG.

    The sample comes from torch's global generator, so it follows
    ``torch.manual_seed``. The result is returned as a plain Python float.
    """
    sample = torch.empty(1)
    # Fill the one-element buffer in place with a single uniform draw.
    sample.uniform_(low, high)
    return float(sample.item())
class SegmentationAugmentation:
    """Config-driven random augmentation for an image and its optional mask.

    Spatial transforms (flips, 90-degree rotations) are applied identically
    to the image and the mask; intensity transforms (brightness/contrast
    jitter, additive Gaussian noise) are applied to the image only.

    Recognised ``config`` keys (all optional):

    * ``random_flip``: enable two independent 50% flips, one along the last
      axis and one along the second-to-last axis.
    * ``random_rotate_90``: enable a rotation by a random multiple of 90
      degrees in the plane of the last two axes.
    * ``random_brightness_contrast``: enable brightness/contrast jitter with
      factors drawn between ``1 - limit`` and ``1 + limit``, where the limits
      are ``brightness_limit`` and ``contrast_limit`` (default 0.15 each).
    * ``random_gaussian_noise``: enable additive noise with standard
      deviation ``gaussian_noise_std`` (default 0.03).

    All randomness comes from torch's global RNG.
    """

    def __init__(self, config: dict[str, Any] | None = None) -> None:
        """Store ``config``; ``None`` (or any falsy value) becomes ``{}``."""
        self.config = config or {}

    def __call__(
        self,
        image: torch.Tensor,
        mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Augment ``image`` and, when given, ``mask``.

        Args:
            image: Tensor whose last two axes are the spatial axes.
            mask: Optional tensor that receives the same spatial transforms
                as ``image``; it is never intensity-augmented.

        Returns:
            ``(image, mask)`` after augmentation; ``mask`` stays ``None`` if
            it was ``None`` on input.
        """
        image, mask = self._apply_spatial(image, mask)
        image = self._apply_intensity(image)
        return image, mask

    def _apply_spatial(
        self,
        image: torch.Tensor,
        mask: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Apply the enabled flips/rotation identically to image and mask."""
        if bool(self.config.get("random_flip", False)):
            # Two independent coin tosses: last axis first, then the
            # second-to-last axis.
            if torch.rand(1).item() < 0.5:
                image = torch.flip(image, dims=(-1,))
                if mask is not None:
                    mask = torch.flip(mask, dims=(-1,))
            if torch.rand(1).item() < 0.5:
                image = torch.flip(image, dims=(-2,))
                if mask is not None:
                    mask = torch.flip(mask, dims=(-2,))

        if bool(self.config.get("random_rotate_90", False)):
            # k is the number of quarter turns, drawn from {0, 1, 2, 3}.
            k = int(torch.randint(0, 4, (1,)).item())
            if k > 0:
                # NOTE: for odd k the sizes of the last two axes are swapped,
                # so non-square inputs change shape.
                image = torch.rot90(image, k=k, dims=(-2, -1))
                if mask is not None:
                    mask = torch.rot90(mask, k=k, dims=(-2, -1))

        return image, mask

    def _apply_intensity(self, image: torch.Tensor) -> torch.Tensor:
        """Apply the enabled intensity transforms to ``image``.

        The result is clamped to ``[0, 1]`` only when a transform actually
        modified the image. With every intensity option disabled (or a
        non-positive noise std) the input is returned untouched, so images
        whose values lie outside ``[0, 1]`` are not silently clipped by a
        spatial-only configuration.
        """
        modified = False

        if bool(self.config.get("random_brightness_contrast", False)):
            brightness = float(self.config.get("brightness_limit", 0.15))
            contrast = float(self.config.get("contrast_limit", 0.15))
            brightness_factor = _rand_uniform(1.0 - brightness, 1.0 + brightness)
            contrast_factor = _rand_uniform(1.0 - contrast, 1.0 + contrast)
            # Contrast scales the deviation from the mean taken over the last
            # two axes; brightness is then a plain multiplicative gain.
            mean = image.mean(dim=(-2, -1), keepdim=True)
            image = (image - mean) * contrast_factor + mean
            image = image * brightness_factor
            modified = True

        if bool(self.config.get("random_gaussian_noise", False)):
            std = float(self.config.get("gaussian_noise_std", 0.03))
            if std > 0:
                image = image + torch.randn_like(image) * std
                modified = True

        # Jitter and noise can push values out of range; keep the output of
        # an intensity transform inside [0, 1].
        return image.clamp(0.0, 1.0) if modified else image
def build_segmentation_augmentation(config: dict[str, Any] | None):
    """Create a ``SegmentationAugmentation`` from ``config``.

    Returns ``None`` when ``config`` is ``None`` or empty; otherwise returns
    a ``SegmentationAugmentation`` built from it.
    """
    return SegmentationAugmentation(config) if config else None
# Names exported by a star-import of this module.
__all__ = ["SegmentationAugmentation", "build_segmentation_augmentation"]
|