supervised.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. from __future__ import annotations
  2. import time
  3. from typing import Any
  4. import torch
  5. from torch.utils.data import DataLoader
  6. from lib.modules import SegmentationModel2d
  7. from lib.tools import build_loss, build_optimizer, build_scheduler
  8. from .base import BaseTrainer
  9. class SupervisedSegmentationTrainer(BaseTrainer):
  10. def __init__(self, cfg: dict[str, Any], args: Any | None = None) -> None:
  11. super().__init__(cfg=cfg, args=args)
  12. self.model: SegmentationModel2d | None = None
  13. self.optimizer = None
  14. self.scheduler = None
  15. self.loader: DataLoader | None = None
  16. self.val_loader: DataLoader | None = None
  17. self.seg_loss = None
  18. def build(self) -> None:
  19. dataset_cfg = self.cfg["dataset"]
  20. model_cfg = self.cfg["model"]
  21. train_cfg = self.cfg["train"]
  22. self.model = SegmentationModel2d(
  23. num_classes=dataset_cfg["num_classes"],
  24. model_name=model_cfg["model_name"],
  25. load_weights=model_cfg.get("load_weights", False),
  26. decoder_channels=model_cfg.get("decoder_channels"),
  27. use_multiscale_features=model_cfg.get("use_multiscale_features", True),
  28. include_patch_embed=model_cfg.get("include_patch_embed", True),
  29. ).to(self.device)
  30. self.optimizer = build_optimizer(self.model, self.cfg["optimizer"])
  31. self.scheduler = build_scheduler(self.optimizer, self.cfg.get("scheduler"))
  32. loss_cfg = self.cfg.get("loss")
  33. if loss_cfg is not None:
  34. self.seg_loss = build_loss(loss_cfg)
  35. self.loader = self._build_segmentation_loader(
  36. split=str(dataset_cfg.get("split", "train")),
  37. split_file=dataset_cfg.get("split_file"),
  38. batch_size=self._resolve_batch_size("batch_size", 4),
  39. shuffle=bool(train_cfg.get("shuffle", True)),
  40. augmentation_config=self.cfg.get("augmentation", {}).get("train"),
  41. )
  42. self.val_loader = self._build_val_loader(
  43. batch_size=self._resolve_batch_size(
  44. "val_batch_size",
  45. int(train_cfg.get("batch_size", 4)),
  46. ),
  47. shuffle=False,
  48. )
  49. self._maybe_resume(
  50. module_map={"model": self.model},
  51. optimizer=self.optimizer,
  52. scheduler=self.scheduler,
  53. )
  54. self._init_swanlab()
  55. def _compute_losses(
  56. self,
  57. image: torch.Tensor,
  58. mask: torch.Tensor,
  59. ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
  60. if self.model is None:
  61. raise RuntimeError("Model is not initialized.")
  62. with torch.autocast(device_type=self.device.type, enabled=self._amp_enabled()):
  63. outputs = self.model(image)
  64. seg_logits = outputs["seg_logits"]
  65. if self.seg_loss is None:
  66. seg_loss = torch.nn.functional.binary_cross_entropy_with_logits(seg_logits, mask)
  67. else:
  68. seg_loss = self.seg_loss(seg_logits, mask)
  69. total_loss = seg_loss
  70. losses = {
  71. "total": total_loss,
  72. "seg": seg_loss,
  73. }
  74. return outputs, losses
  75. @staticmethod
  76. def _detach_metrics(losses: dict[str, torch.Tensor]) -> dict[str, float]:
  77. return {key: float(value.detach().cpu()) for key, value in losses.items()}
  78. def _validate(self) -> dict[str, float] | None:
  79. if self.model is None or self.val_loader is None:
  80. return None
  81. self.model.eval()
  82. metrics = self._build_validation_metrics()
  83. total = 0.0
  84. seg = 0.0
  85. steps = 0
  86. with torch.no_grad():
  87. for batch in self.val_loader:
  88. image = batch["image"].to(self.device)
  89. mask = batch["mask"].to(self.device)
  90. outputs, losses = self._compute_losses(image, mask)
  91. total += float(losses["total"].detach().cpu())
  92. seg += float(losses["seg"].detach().cpu())
  93. self._update_validation_metrics(
  94. metrics,
  95. logits=outputs["seg_logits"],
  96. target=mask,
  97. )
  98. steps += 1
  99. if steps == 0:
  100. return None
  101. val_metrics = {
  102. "total": total / steps,
  103. "seg": seg / steps,
  104. }
  105. val_metrics.update(self._compute_validation_metric_values(metrics))
  106. return val_metrics
  107. def train(self) -> None:
  108. if self.model is None or self.loader is None or self.optimizer is None:
  109. raise RuntimeError("Trainer.build() must be called before train().")
  110. epochs = int(self.cfg["train"].get("epochs", 1))
  111. accum_steps = self._accum_steps()
  112. try:
  113. self._print_training_setup(
  114. model_map={"model": self.model},
  115. loader_map={"train": self.loader, "val": self.val_loader},
  116. optimizer=self.optimizer,
  117. scheduler=self.scheduler,
  118. )
  119. for epoch in range(self.start_epoch, epochs):
  120. self.model.train()
  121. self.optimizer.zero_grad()
  122. train_metric_sums = {
  123. "total": 0.0,
  124. "seg": 0.0,
  125. }
  126. train_metrics: dict[str, float] | None = None
  127. end_time = time.perf_counter()
  128. num_steps = len(self.loader)
  129. for step, batch in enumerate(self.loader, start=1):
  130. data_time = time.perf_counter() - end_time
  131. iter_start = time.perf_counter()
  132. image = batch["image"].to(self.device)
  133. mask = batch["mask"].to(self.device)
  134. _, losses = self._compute_losses(image, mask)
  135. scaled_total_loss = losses["total"] / accum_steps
  136. self.grad_scaler.scale(scaled_total_loss).backward()
  137. grad_norm = None
  138. should_step = (step % accum_steps == 0) or (step == num_steps)
  139. if should_step:
  140. if self._grad_clip_enabled():
  141. self.grad_scaler.unscale_(self.optimizer)
  142. grad_norm = self._clip_gradients(self.model)
  143. self.grad_scaler.step(self.optimizer)
  144. self.grad_scaler.update()
  145. self.optimizer.zero_grad()
  146. train_metrics = self._detach_metrics(losses)
  147. if grad_norm is not None:
  148. train_metrics["grad_norm"] = grad_norm
  149. for key, value in train_metrics.items():
  150. train_metric_sums.setdefault(key, 0.0)
  151. train_metric_sums[key] += value
  152. iter_time = time.perf_counter() - iter_start
  153. self._maybe_log_step(
  154. epoch=epoch,
  155. step=step,
  156. num_steps=num_steps,
  157. data_time=data_time,
  158. iter_time=iter_time,
  159. metrics=train_metrics,
  160. prefix="train",
  161. )
  162. end_time = time.perf_counter()
  163. if self.scheduler is not None:
  164. self.scheduler.step()
  165. if train_metrics is None:
  166. raise RuntimeError("Training loader is empty.")
  167. train_metrics = self._average_metric_sums(train_metric_sums, num_steps)
  168. val_metrics = self._validate() if self._should_validate(epoch) else None
  169. summary, should_stop = self._finalize_epoch(
  170. epoch=epoch,
  171. train_metrics=train_metrics,
  172. val_metrics=val_metrics,
  173. checkpoint_state={
  174. "model": self.model.state_dict(),
  175. "optimizer": self.optimizer.state_dict(),
  176. "scheduler": self.scheduler.state_dict() if self.scheduler is not None else None,
  177. },
  178. )
  179. print(summary)
  180. if should_stop:
  181. print({"epoch": epoch, "message": "early stopping triggered"})
  182. break
  183. finally:
  184. self._close_loggers()