supervised.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  1. from __future__ import annotations
  2. import time
  3. from typing import Any
  4. import torch
  5. from torch.utils.data import DataLoader
  6. from lib.modules import XNet2d
  7. from lib.tools import build_loss, build_optimizer, build_scheduler
  8. from .base import BaseTrainer
  9. class SupervisedSegmentationTrainer(BaseTrainer):
  10. def __init__(self, cfg: dict[str, Any], args: Any | None = None) -> None:
  11. super().__init__(cfg=cfg, args=args)
  12. self.model: XNet2d | None = None
  13. self.optimizer = None
  14. self.scheduler = None
  15. self.loader: DataLoader | None = None
  16. self.val_loader: DataLoader | None = None
  17. self.seg_loss = None
  18. def build(self) -> None:
  19. dataset_cfg = self.cfg["dataset"]
  20. model_cfg = self.cfg["model"]
  21. train_cfg = self.cfg["train"]
  22. self.model = XNet2d(
  23. in_channels=int(model_cfg.get("in_channels", dataset_cfg.get("in_channels", 3))),
  24. num_classes=int(dataset_cfg["num_classes"]),
  25. encoder_channels=tuple(model_cfg.get("encoder_channels", (32, 64, 128, 192))),
  26. encoder_depths=tuple(model_cfg.get("encoder_depths", (2, 2, 2, 2))),
  27. decoder_channels=tuple(model_cfg.get("decoder_channels", (128, 64, 32))),
  28. stem_channels=int(model_cfg.get("stem_channels", 24)),
  29. bottleneck_depth=int(model_cfg.get("bottleneck_depth", 1)),
  30. global_ratio=float(model_cfg.get("global_ratio", 2.0)),
  31. wavelet_type=str(model_cfg.get("wavelet_type", "haar")),
  32. wavelet_level=int(model_cfg.get("wavelet_level", 1)),
  33. use_wavelet_branch=bool(model_cfg.get("use_wavelet_branch", True)),
  34. use_global_branch_stage1=bool(model_cfg.get("use_global_branch_stage1", False)),
  35. ssm_d_state=int(model_cfg.get("ssm_d_state", 16)),
  36. ssm_forward_type=str(model_cfg.get("ssm_forward_type", "v3")),
  37. ssm_backend=str(model_cfg.get("ssm_backend", "auto")),
  38. use_frequency_refine=bool(model_cfg.get("use_frequency_refine", True)),
  39. low_freq_radius_h=float(model_cfg.get("low_freq_radius_h", 0.25)),
  40. low_freq_radius_w=float(model_cfg.get("low_freq_radius_w", 0.25)),
  41. learnable_low_freq_radius=bool(
  42. model_cfg.get("learnable_low_freq_radius", True)
  43. ),
  44. guide_mode=str(model_cfg.get("guide_mode", "affine")),
  45. out_channels=model_cfg.get("out_channels"),
  46. ).to(self.device)
  47. self.optimizer = build_optimizer(self.model, self.cfg["optimizer"])
  48. self.scheduler = build_scheduler(self.optimizer, self.cfg.get("scheduler"))
  49. loss_cfg = self.cfg.get("loss")
  50. if loss_cfg is not None:
  51. self.seg_loss = build_loss(loss_cfg)
  52. self.loader = self._build_segmentation_loader(
  53. split=str(dataset_cfg.get("split", "train")),
  54. split_file=dataset_cfg.get("split_file"),
  55. batch_size=self._resolve_batch_size("batch_size", 4),
  56. shuffle=bool(train_cfg.get("shuffle", True)),
  57. augmentation_config=self.cfg.get("augmentation", {}).get("train"),
  58. )
  59. self.val_loader = self._build_val_loader(
  60. batch_size=self._resolve_batch_size(
  61. "val_batch_size",
  62. int(train_cfg.get("batch_size", 4)),
  63. ),
  64. shuffle=False,
  65. )
  66. self._maybe_resume(
  67. module_map={"model": self.model},
  68. optimizer=self.optimizer,
  69. scheduler=self.scheduler,
  70. )
  71. self._init_swanlab()
  72. def _compute_losses(
  73. self,
  74. image: torch.Tensor,
  75. mask: torch.Tensor,
  76. ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
  77. if self.model is None:
  78. raise RuntimeError("Model is not initialized.")
  79. with torch.autocast(device_type=self.device.type, enabled=self._amp_enabled()):
  80. outputs = self.model(image)
  81. seg_logits = outputs["seg_logits"]
  82. if self.seg_loss is None:
  83. seg_loss = torch.nn.functional.binary_cross_entropy_with_logits(seg_logits, mask)
  84. else:
  85. seg_loss = self.seg_loss(seg_logits, mask)
  86. total_loss = seg_loss
  87. losses = {
  88. "total": total_loss,
  89. "seg": seg_loss,
  90. }
  91. return outputs, losses
  92. @staticmethod
  93. def _detach_metrics(losses: dict[str, torch.Tensor]) -> dict[str, float]:
  94. return {key: float(value.detach().cpu()) for key, value in losses.items()}
  95. def _validate(self) -> dict[str, float] | None:
  96. if self.model is None or self.val_loader is None:
  97. return None
  98. self.model.eval()
  99. metrics = self._build_validation_metrics()
  100. total = 0.0
  101. seg = 0.0
  102. steps = 0
  103. with torch.no_grad():
  104. for batch in self.val_loader:
  105. image = batch["image"].to(self.device)
  106. mask = batch["mask"].to(self.device)
  107. outputs, losses = self._compute_losses(image, mask)
  108. total += float(losses["total"].detach().cpu())
  109. seg += float(losses["seg"].detach().cpu())
  110. self._update_validation_metrics(
  111. metrics,
  112. logits=outputs["seg_logits"],
  113. target=mask,
  114. )
  115. steps += 1
  116. if steps == 0:
  117. return None
  118. val_metrics = {
  119. "total": total / steps,
  120. "seg": seg / steps,
  121. }
  122. val_metrics.update(self._compute_validation_metric_values(metrics))
  123. return val_metrics
  124. def train(self) -> None:
  125. if self.model is None or self.loader is None or self.optimizer is None:
  126. raise RuntimeError("Trainer.build() must be called before train().")
  127. epochs = int(self.cfg["train"].get("epochs", 1))
  128. accum_steps = self._accum_steps()
  129. try:
  130. self._print_training_setup(
  131. model_map={"model": self.model},
  132. loader_map={"train": self.loader, "val": self.val_loader},
  133. optimizer=self.optimizer,
  134. scheduler=self.scheduler,
  135. )
  136. for epoch in range(self.start_epoch, epochs):
  137. self.model.train()
  138. self.optimizer.zero_grad()
  139. train_metric_sums = {
  140. "total": 0.0,
  141. "seg": 0.0,
  142. }
  143. train_metrics: dict[str, float] | None = None
  144. end_time = time.perf_counter()
  145. num_steps = len(self.loader)
  146. for step, batch in enumerate(self.loader, start=1):
  147. data_time = time.perf_counter() - end_time
  148. iter_start = time.perf_counter()
  149. image = batch["image"].to(self.device)
  150. mask = batch["mask"].to(self.device)
  151. _, losses = self._compute_losses(image, mask)
  152. scaled_total_loss = losses["total"] / accum_steps
  153. self.grad_scaler.scale(scaled_total_loss).backward()
  154. grad_norm = None
  155. should_step = (step % accum_steps == 0) or (step == num_steps)
  156. if should_step:
  157. if self._grad_clip_enabled():
  158. self.grad_scaler.unscale_(self.optimizer)
  159. grad_norm = self._clip_gradients(self.model)
  160. self.grad_scaler.step(self.optimizer)
  161. self.grad_scaler.update()
  162. self.optimizer.zero_grad()
  163. train_metrics = self._detach_metrics(losses)
  164. if grad_norm is not None:
  165. train_metrics["grad_norm"] = grad_norm
  166. for key, value in train_metrics.items():
  167. train_metric_sums.setdefault(key, 0.0)
  168. train_metric_sums[key] += value
  169. iter_time = time.perf_counter() - iter_start
  170. self._maybe_log_step(
  171. epoch=epoch,
  172. step=step,
  173. num_steps=num_steps,
  174. data_time=data_time,
  175. iter_time=iter_time,
  176. metrics=train_metrics,
  177. prefix="train",
  178. )
  179. end_time = time.perf_counter()
  180. if self.scheduler is not None:
  181. self.scheduler.step()
  182. if train_metrics is None:
  183. raise RuntimeError("Training loader is empty.")
  184. train_metrics = self._average_metric_sums(train_metric_sums, num_steps)
  185. val_metrics = self._validate() if self._should_validate(epoch) else None
  186. summary, should_stop = self._finalize_epoch(
  187. epoch=epoch,
  188. train_metrics=train_metrics,
  189. val_metrics=val_metrics,
  190. checkpoint_state={
  191. "model": self.model.state_dict(),
  192. "optimizer": self.optimizer.state_dict(),
  193. "scheduler": self.scheduler.state_dict() if self.scheduler is not None else None,
  194. },
  195. )
  196. print(summary)
  197. if should_stop:
  198. print({"epoch": epoch, "message": "early stopping triggered"})
  199. break
  200. finally:
  201. self._close_loggers()