# supervised.py - supervised segmentation trainer (file-viewer header and line-number gutter removed).
from __future__ import annotations

import time
from typing import Any

import torch
from torch.utils.data import DataLoader

from lib.modules import SegmentationNet2d
from lib.tools import (
    BinaryBoundaryLoss,
    MaskBoundaryConsistencyLoss,
    build_optimizer,
    build_scheduler,
    mask_to_boundary_map,
)

from .base import BaseTrainer


  15. class SupervisedSegmentationTrainer(BaseTrainer):
  16. def __init__(self, cfg: dict[str, Any], args: Any | None = None) -> None:
  17. super().__init__(cfg=cfg, args=args)
  18. self.model: SegmentationNet2d | None = None
  19. self.optimizer = None
  20. self.scheduler = None
  21. self.loader: DataLoader | None = None
  22. self.val_loader: DataLoader | None = None
  23. self.seg_loss = None
  24. self.boundary_loss = BinaryBoundaryLoss()
  25. self.consistency_loss = MaskBoundaryConsistencyLoss()
  26. def build(self) -> None:
  27. dataset_cfg = self.cfg["dataset"]
  28. model_cfg = self.cfg["model"]
  29. train_cfg = self.cfg["train"]
  30. self.model = SegmentationNet2d(
  31. num_classes=dataset_cfg["num_classes"],
  32. model_name=model_cfg["model_name"],
  33. load_weights=model_cfg.get("load_weights", False),
  34. decoder_channels=model_cfg.get("decoder_channels"),
  35. fwta_wavelet=model_cfg.get("fwta_wavelet", "haar"),
  36. fwta_level=model_cfg.get("fwta_level", 1),
  37. fwta_sigma_ratio=model_cfg.get("fwta_sigma_ratio", 0.35),
  38. fwta_tau_fourier=model_cfg.get("fwta_tau_fourier", 0.15),
  39. fwta_gate_temperature=model_cfg.get("fwta_gate_temperature", 1.0),
  40. ).to(self.device)
  41. self.optimizer = build_optimizer(self.model, self.cfg["optimizer"])
  42. self.scheduler = build_scheduler(self.optimizer, self.cfg.get("scheduler"))
  43. self.loader = self._build_segmentation_loader(
  44. split=str(dataset_cfg.get("split", "train")),
  45. split_file=dataset_cfg.get("split_file"),
  46. batch_size=self._resolve_batch_size("batch_size", 4),
  47. shuffle=bool(train_cfg.get("shuffle", True)),
  48. )
  49. self.val_loader = self._build_val_loader(
  50. batch_size=self._resolve_batch_size(
  51. "val_batch_size",
  52. int(train_cfg.get("batch_size", 4)),
  53. ),
  54. shuffle=False,
  55. )
  56. self._maybe_resume(
  57. module_map={"model": self.model},
  58. optimizer=self.optimizer,
  59. scheduler=self.scheduler,
  60. )
  61. self._init_swanlab()
  62. def _compute_losses(
  63. self,
  64. image: torch.Tensor,
  65. mask: torch.Tensor,
  66. ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
  67. if self.model is None:
  68. raise RuntimeError("Model is not initialized.")
  69. with torch.autocast(device_type=self.device.type, enabled=self._amp_enabled()):
  70. outputs = self.model(image)
  71. seg_logits = outputs["seg_logits"]
  72. boundary_logits = outputs["boundary_logits"]
  73. seg_loss = torch.nn.functional.binary_cross_entropy_with_logits(seg_logits, mask)
  74. boundary_target = mask_to_boundary_map(mask)
  75. boundary_loss = self.boundary_loss(boundary_logits, boundary_target)
  76. consistency_loss = self.consistency_loss(seg_logits, boundary_logits)
  77. total_loss = seg_loss + boundary_loss + 0.1 * consistency_loss
  78. losses = {
  79. "total": total_loss,
  80. "seg": seg_loss,
  81. "boundary": boundary_loss,
  82. "consistency": consistency_loss,
  83. }
  84. return outputs, losses
  85. @staticmethod
  86. def _detach_metrics(losses: dict[str, torch.Tensor]) -> dict[str, float]:
  87. return {key: float(value.detach().cpu()) for key, value in losses.items()}
  88. def _validate(self) -> dict[str, float] | None:
  89. if self.model is None or self.val_loader is None:
  90. return None
  91. self.model.eval()
  92. metrics = self._build_validation_metrics()
  93. total = 0.0
  94. seg = 0.0
  95. boundary = 0.0
  96. consistency = 0.0
  97. steps = 0
  98. with torch.no_grad():
  99. for batch in self.val_loader:
  100. image = batch["image"].to(self.device)
  101. mask = batch["mask"].to(self.device)
  102. outputs, losses = self._compute_losses(image, mask)
  103. total += float(losses["total"].detach().cpu())
  104. seg += float(losses["seg"].detach().cpu())
  105. boundary += float(losses["boundary"].detach().cpu())
  106. consistency += float(losses["consistency"].detach().cpu())
  107. self._update_validation_metrics(
  108. metrics,
  109. logits=outputs["seg_logits"],
  110. target=mask,
  111. )
  112. steps += 1
  113. if steps == 0:
  114. return None
  115. val_metrics = {
  116. "total": total / steps,
  117. "seg": seg / steps,
  118. "boundary": boundary / steps,
  119. "consistency": consistency / steps,
  120. }
  121. val_metrics.update(self._compute_validation_metric_values(metrics))
  122. return val_metrics
  123. def train(self) -> None:
  124. if self.model is None or self.loader is None or self.optimizer is None:
  125. raise RuntimeError("Trainer.build() must be called before train().")
  126. epochs = int(self.cfg["train"].get("epochs", 1))
  127. try:
  128. self._print_training_setup(
  129. model_map={"model": self.model},
  130. loader_map={"train": self.loader, "val": self.val_loader},
  131. optimizer=self.optimizer,
  132. scheduler=self.scheduler,
  133. )
  134. for epoch in range(self.start_epoch, epochs):
  135. self.model.train()
  136. train_metric_sums = {
  137. "total": 0.0,
  138. "seg": 0.0,
  139. "boundary": 0.0,
  140. "consistency": 0.0,
  141. }
  142. train_metrics: dict[str, float] | None = None
  143. end_time = time.perf_counter()
  144. num_steps = len(self.loader)
  145. for step, batch in enumerate(self.loader, start=1):
  146. data_time = time.perf_counter() - end_time
  147. iter_start = time.perf_counter()
  148. image = batch["image"].to(self.device)
  149. mask = batch["mask"].to(self.device)
  150. _, losses = self._compute_losses(image, mask)
  151. self.optimizer.zero_grad()
  152. self.grad_scaler.scale(losses["total"]).backward()
  153. grad_norm = None
  154. if self._grad_clip_enabled():
  155. self.grad_scaler.unscale_(self.optimizer)
  156. grad_norm = self._clip_gradients(self.model)
  157. self.grad_scaler.step(self.optimizer)
  158. self.grad_scaler.update()
  159. train_metrics = self._detach_metrics(losses)
  160. if grad_norm is not None:
  161. train_metrics["grad_norm"] = grad_norm
  162. for key, value in train_metrics.items():
  163. train_metric_sums.setdefault(key, 0.0)
  164. train_metric_sums[key] += value
  165. iter_time = time.perf_counter() - iter_start
  166. self._maybe_log_step(
  167. epoch=epoch,
  168. step=step,
  169. num_steps=num_steps,
  170. data_time=data_time,
  171. iter_time=iter_time,
  172. metrics=train_metrics,
  173. prefix="train",
  174. )
  175. end_time = time.perf_counter()
  176. if self.scheduler is not None:
  177. self.scheduler.step()
  178. if train_metrics is None:
  179. raise RuntimeError("Training loader is empty.")
  180. train_metrics = self._average_metric_sums(train_metric_sums, num_steps)
  181. val_metrics = self._validate() if self._should_validate(epoch) else None
  182. summary, should_stop = self._finalize_epoch(
  183. epoch=epoch,
  184. train_metrics=train_metrics,
  185. val_metrics=val_metrics,
  186. checkpoint_state={
  187. "model": self.model.state_dict(),
  188. "optimizer": self.optimizer.state_dict(),
  189. "scheduler": self.scheduler.state_dict() if self.scheduler is not None else None,
  190. },
  191. )
  192. print(summary)
  193. if should_stop:
  194. print({"epoch": epoch, "message": "early stopping triggered"})
  195. break
  196. finally:
  197. self._close_loggers()