# supervised.py
from __future__ import annotations

import time
from typing import Any

import torch
from torch.utils.data import DataLoader

from lib.modules import XNet2d
from lib.tools import build_loss, build_optimizer, build_scheduler

from .base import BaseTrainer


  9. class SupervisedSegmentationTrainer(BaseTrainer):
  10. def __init__(self, cfg: dict[str, Any], args: Any | None = None) -> None:
  11. super().__init__(cfg=cfg, args=args)
  12. self.model: XNet2d | None = None
  13. self.optimizer = None
  14. self.scheduler = None
  15. self.loader: DataLoader | None = None
  16. self.val_loader: DataLoader | None = None
  17. self.seg_loss = None
  18. def build(self) -> None:
  19. dataset_cfg = self.cfg["dataset"]
  20. model_cfg = self.cfg["model"]
  21. train_cfg = self.cfg["train"]
  22. self.model = XNet2d(
  23. in_channels=int(model_cfg.get("in_channels", dataset_cfg.get("in_channels", 3))),
  24. num_classes=int(dataset_cfg["num_classes"]),
  25. encoder_channels=tuple(model_cfg.get("encoder_channels", (32, 64, 128, 192))),
  26. encoder_depths=tuple(model_cfg.get("encoder_depths", (2, 2, 2, 2))),
  27. decoder_channels=tuple(model_cfg.get("decoder_channels", (128, 64, 32))),
  28. stem_channels=int(model_cfg.get("stem_channels", 24)),
  29. bottleneck_depth=int(model_cfg.get("bottleneck_depth", 1)),
  30. global_ratio=float(model_cfg.get("global_ratio", 2.0)),
  31. wavelet_type=str(model_cfg.get("wavelet_type", "haar")),
  32. wavelet_level=int(model_cfg.get("wavelet_level", 1)),
  33. use_wavelet_branch=bool(model_cfg.get("use_wavelet_branch", True)),
  34. use_global_branch_stage1=bool(model_cfg.get("use_global_branch_stage1", False)),
  35. ssm_d_state=int(model_cfg.get("ssm_d_state", 16)),
  36. ssm_forward_type=str(model_cfg.get("ssm_forward_type", "v3")),
  37. ssm_backend=str(model_cfg.get("ssm_backend", "auto")),
  38. use_frequency_refine=bool(model_cfg.get("use_frequency_refine", True)),
  39. guide_mode=str(model_cfg.get("guide_mode", "affine")),
  40. out_channels=model_cfg.get("out_channels"),
  41. ).to(self.device)
  42. self.optimizer = build_optimizer(self.model, self.cfg["optimizer"])
  43. self.scheduler = build_scheduler(self.optimizer, self.cfg.get("scheduler"))
  44. loss_cfg = self.cfg.get("loss")
  45. if loss_cfg is not None:
  46. self.seg_loss = build_loss(loss_cfg)
  47. self.loader = self._build_segmentation_loader(
  48. split=str(dataset_cfg.get("split", "train")),
  49. split_file=dataset_cfg.get("split_file"),
  50. batch_size=self._resolve_batch_size("batch_size", 4),
  51. shuffle=bool(train_cfg.get("shuffle", True)),
  52. augmentation_config=self.cfg.get("augmentation", {}).get("train"),
  53. )
  54. self.val_loader = self._build_val_loader(
  55. batch_size=self._resolve_batch_size(
  56. "val_batch_size",
  57. int(train_cfg.get("batch_size", 4)),
  58. ),
  59. shuffle=False,
  60. )
  61. self._maybe_resume(
  62. module_map={"model": self.model},
  63. optimizer=self.optimizer,
  64. scheduler=self.scheduler,
  65. )
  66. self._init_swanlab()
  67. def _compute_losses(
  68. self,
  69. image: torch.Tensor,
  70. mask: torch.Tensor,
  71. ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
  72. if self.model is None:
  73. raise RuntimeError("Model is not initialized.")
  74. with torch.autocast(device_type=self.device.type, enabled=self._amp_enabled()):
  75. outputs = self.model(image)
  76. seg_logits = outputs["seg_logits"]
  77. if self.seg_loss is None:
  78. seg_loss = torch.nn.functional.binary_cross_entropy_with_logits(seg_logits, mask)
  79. else:
  80. seg_loss = self.seg_loss(seg_logits, mask)
  81. total_loss = seg_loss
  82. losses = {
  83. "total": total_loss,
  84. "seg": seg_loss,
  85. }
  86. return outputs, losses
  87. @staticmethod
  88. def _detach_metrics(losses: dict[str, torch.Tensor]) -> dict[str, float]:
  89. return {key: float(value.detach().cpu()) for key, value in losses.items()}
  90. def _validate(self) -> dict[str, float] | None:
  91. if self.model is None or self.val_loader is None:
  92. return None
  93. self.model.eval()
  94. metrics = self._build_validation_metrics()
  95. total = 0.0
  96. seg = 0.0
  97. steps = 0
  98. with torch.no_grad():
  99. for batch in self.val_loader:
  100. image = batch["image"].to(self.device)
  101. mask = batch["mask"].to(self.device)
  102. outputs, losses = self._compute_losses(image, mask)
  103. total += float(losses["total"].detach().cpu())
  104. seg += float(losses["seg"].detach().cpu())
  105. self._update_validation_metrics(
  106. metrics,
  107. logits=outputs["seg_logits"],
  108. target=mask,
  109. )
  110. steps += 1
  111. if steps == 0:
  112. return None
  113. val_metrics = {
  114. "total": total / steps,
  115. "seg": seg / steps,
  116. }
  117. val_metrics.update(self._compute_validation_metric_values(metrics))
  118. return val_metrics
  119. def train(self) -> None:
  120. if self.model is None or self.loader is None or self.optimizer is None:
  121. raise RuntimeError("Trainer.build() must be called before train().")
  122. epochs = int(self.cfg["train"].get("epochs", 1))
  123. accum_steps = self._accum_steps()
  124. try:
  125. self._print_training_setup(
  126. model_map={"model": self.model},
  127. loader_map={"train": self.loader, "val": self.val_loader},
  128. optimizer=self.optimizer,
  129. scheduler=self.scheduler,
  130. )
  131. for epoch in range(self.start_epoch, epochs):
  132. self.model.train()
  133. self.optimizer.zero_grad()
  134. train_metric_sums = {
  135. "total": 0.0,
  136. "seg": 0.0,
  137. }
  138. train_metrics: dict[str, float] | None = None
  139. end_time = time.perf_counter()
  140. num_steps = len(self.loader)
  141. for step, batch in enumerate(self.loader, start=1):
  142. data_time = time.perf_counter() - end_time
  143. iter_start = time.perf_counter()
  144. image = batch["image"].to(self.device)
  145. mask = batch["mask"].to(self.device)
  146. _, losses = self._compute_losses(image, mask)
  147. scaled_total_loss = losses["total"] / accum_steps
  148. self.grad_scaler.scale(scaled_total_loss).backward()
  149. grad_norm = None
  150. should_step = (step % accum_steps == 0) or (step == num_steps)
  151. if should_step:
  152. if self._grad_clip_enabled():
  153. self.grad_scaler.unscale_(self.optimizer)
  154. grad_norm = self._clip_gradients(self.model)
  155. self.grad_scaler.step(self.optimizer)
  156. self.grad_scaler.update()
  157. self.optimizer.zero_grad()
  158. train_metrics = self._detach_metrics(losses)
  159. if grad_norm is not None:
  160. train_metrics["grad_norm"] = grad_norm
  161. for key, value in train_metrics.items():
  162. train_metric_sums.setdefault(key, 0.0)
  163. train_metric_sums[key] += value
  164. iter_time = time.perf_counter() - iter_start
  165. self._maybe_log_step(
  166. epoch=epoch,
  167. step=step,
  168. num_steps=num_steps,
  169. data_time=data_time,
  170. iter_time=iter_time,
  171. metrics=train_metrics,
  172. prefix="train",
  173. )
  174. end_time = time.perf_counter()
  175. if self.scheduler is not None:
  176. self.scheduler.step()
  177. if train_metrics is None:
  178. raise RuntimeError("Training loader is empty.")
  179. train_metrics = self._average_metric_sums(train_metric_sums, num_steps)
  180. val_metrics = self._validate() if self._should_validate(epoch) else None
  181. summary, should_stop = self._finalize_epoch(
  182. epoch=epoch,
  183. train_metrics=train_metrics,
  184. val_metrics=val_metrics,
  185. checkpoint_state={
  186. "model": self.model.state_dict(),
  187. "optimizer": self.optimizer.state_dict(),
  188. "scheduler": self.scheduler.state_dict() if self.scheduler is not None else None,
  189. },
  190. )
  191. print(summary)
  192. if should_stop:
  193. print({"epoch": epoch, "message": "early stopping triggered"})
  194. break
  195. finally:
  196. self._close_loggers()