# Config-driven random augmentations for segmentation samples (image + optional mask).
from __future__ import annotations

from typing import Any

import torch


def _rand_uniform(low: float, high: float) -> float:
    """Return one Python float sampled uniformly between *low* and *high*.

    The sample is drawn through ``Tensor.uniform_`` without an explicit
    generator, so it follows torch's default RNG state.
    """
    return float(torch.empty(1).uniform_(low, high).item())


class SegmentationAugmentation:
    """Apply random spatial and intensity augmentations to an image/mask pair.

    Behaviour is driven by a plain ``dict``; every key is optional:

    * ``random_flip`` (default ``False``): flip along the last axis and,
      independently, along the second-to-last axis, each with probability 0.5.
    * ``random_rotate_90`` (default ``False``): rotate by 0-3 quarter turns in
      the plane of the last two axes.
    * ``random_brightness_contrast`` (default ``False``): scale contrast around
      the mean over the last two axes, then scale brightness. The factors are
      drawn from ``1 +/- contrast_limit`` and ``1 +/- brightness_limit``
      (both limits default to ``0.15``).
    * ``random_gaussian_noise`` (default ``False``): add zero-mean Gaussian
      noise with standard deviation ``gaussian_noise_std`` (default ``0.03``).

    Spatial transforms are applied identically to the image and the mask;
    intensity transforms touch only the image.

    NOTE(review): the last two axes are presumably (height, width) and
    intensity-augmented images are presumably expected in ``[0, 1]`` (results
    are clamped to that range) -- confirm against the data pipeline.
    """

    def __init__(self, config: dict[str, Any] | None = None) -> None:
        """Store *config*; ``None`` or an empty mapping disables every transform."""
        self.config = config or {}

    def __call__(
        self,
        image: torch.Tensor,
        mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Augment *image* (and *mask*, when given) and return both.

        Spatial transforms run first on the pair, then intensity transforms
        run on the image only. ``mask`` is returned as ``None`` when it was
        passed as ``None``.
        """
        image, mask = self._apply_spatial(image, mask)
        image = self._apply_intensity(image)
        return image, mask

    def _apply_spatial(
        self,
        image: torch.Tensor,
        mask: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Apply the enabled flips / quarter-turn rotation to image and mask alike."""
        if bool(self.config.get("random_flip", False)):
            # Each axis gets its own coin toss; the mask reuses the image's
            # outcome so the pair stays aligned.
            if torch.rand(1).item() < 0.5:
                image = torch.flip(image, dims=(-1,))
                if mask is not None:
                    mask = torch.flip(mask, dims=(-1,))
            if torch.rand(1).item() < 0.5:
                image = torch.flip(image, dims=(-2,))
                if mask is not None:
                    mask = torch.flip(mask, dims=(-2,))

        if bool(self.config.get("random_rotate_90", False)):
            # k is drawn from {0, 1, 2, 3}; k == 0 leaves the sample untouched.
            k = int(torch.randint(0, 4, (1,)).item())
            if k > 0:
                image = torch.rot90(image, k=k, dims=(-2, -1))
                if mask is not None:
                    mask = torch.rot90(mask, k=k, dims=(-2, -1))

        return image, mask

    def _apply_intensity(self, image: torch.Tensor) -> torch.Tensor:
        """Apply the enabled brightness/contrast and noise transforms.

        The result is clamped to ``[0, 1]`` only when at least one intensity
        transform actually modified the image. When none ran, the input is
        returned unchanged, so a spatial-only configuration never rewrites
        pixel values or dtype.
        """
        modified = False

        if bool(self.config.get("random_brightness_contrast", False)):
            brightness = float(self.config.get("brightness_limit", 0.15))
            contrast = float(self.config.get("contrast_limit", 0.15))
            brightness_factor = _rand_uniform(1.0 - brightness, 1.0 + brightness)
            contrast_factor = _rand_uniform(1.0 - contrast, 1.0 + contrast)
            # Contrast is scaled around the mean over the last two axes so the
            # average intensity is preserved before brightness is applied.
            mean = image.mean(dim=(-2, -1), keepdim=True)
            image = (image - mean) * contrast_factor + mean
            image = image * brightness_factor
            modified = True

        if bool(self.config.get("random_gaussian_noise", False)):
            std = float(self.config.get("gaussian_noise_std", 0.03))
            if std > 0:
                image = image + torch.randn_like(image) * std
                modified = True

        if not modified:
            return image
        return image.clamp(0.0, 1.0)


def build_segmentation_augmentation(
    config: dict[str, Any] | None,
) -> SegmentationAugmentation | None:
    """Build an augmentation from *config*, or return ``None`` when it is empty.

    A falsy *config* (``None`` or ``{}``) means "no augmentation", so callers
    can skip the augmentation step entirely.
    """
    if not config:
        return None
    return SegmentationAugmentation(config)


__all__ = ["SegmentationAugmentation", "build_segmentation_augmentation"]