# augment.py (2.6 KB)
  1. from __future__ import annotations
  2. from typing import Any
  3. import torch
  4. def _rand_uniform(low: float, high: float) -> float:
  5. return float(torch.empty(1).uniform_(low, high).item())
  6. class SegmentationAugmentation:
  7. def __init__(self, config: dict[str, Any] | None = None) -> None:
  8. self.config = config or {}
  9. def __call__(
  10. self,
  11. image: torch.Tensor,
  12. mask: torch.Tensor | None = None,
  13. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  14. image, mask = self._apply_spatial(image, mask)
  15. image = self._apply_intensity(image)
  16. return image, mask
  17. def _apply_spatial(
  18. self,
  19. image: torch.Tensor,
  20. mask: torch.Tensor | None,
  21. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  22. if bool(self.config.get("random_flip", False)):
  23. if torch.rand(1).item() < 0.5:
  24. image = torch.flip(image, dims=(-1,))
  25. if mask is not None:
  26. mask = torch.flip(mask, dims=(-1,))
  27. if torch.rand(1).item() < 0.5:
  28. image = torch.flip(image, dims=(-2,))
  29. if mask is not None:
  30. mask = torch.flip(mask, dims=(-2,))
  31. if bool(self.config.get("random_rotate_90", False)):
  32. k = int(torch.randint(0, 4, (1,)).item())
  33. if k > 0:
  34. image = torch.rot90(image, k=k, dims=(-2, -1))
  35. if mask is not None:
  36. mask = torch.rot90(mask, k=k, dims=(-2, -1))
  37. return image, mask
  38. def _apply_intensity(self, image: torch.Tensor) -> torch.Tensor:
  39. if bool(self.config.get("random_brightness_contrast", False)):
  40. brightness = float(self.config.get("brightness_limit", 0.15))
  41. contrast = float(self.config.get("contrast_limit", 0.15))
  42. brightness_factor = _rand_uniform(1.0 - brightness, 1.0 + brightness)
  43. contrast_factor = _rand_uniform(1.0 - contrast, 1.0 + contrast)
  44. mean = image.mean(dim=(-2, -1), keepdim=True)
  45. image = (image - mean) * contrast_factor + mean
  46. image = image * brightness_factor
  47. if bool(self.config.get("random_gaussian_noise", False)):
  48. std = float(self.config.get("gaussian_noise_std", 0.03))
  49. if std > 0:
  50. image = image + torch.randn_like(image) * std
  51. return image.clamp(0.0, 1.0)
  52. def build_segmentation_augmentation(config: dict[str, Any] | None):
  53. if not config:
  54. return None
  55. return SegmentationAugmentation(config)
  56. __all__ = ["SegmentationAugmentation", "build_segmentation_augmentation"]