| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221 |
- from __future__ import annotations
- import time
- from typing import Any
- import torch
- from torch.utils.data import DataLoader
- from lib.modules import XNet2d
- from lib.tools import build_loss, build_optimizer, build_scheduler
- from .base import BaseTrainer
class SupervisedSegmentationTrainer(BaseTrainer):
    """Fully supervised trainer for an ``XNet2d`` segmentation model.

    ``build()`` creates the model, optimizer, scheduler, optional loss and the
    train/validation loaders from the configuration. ``train()`` then runs the
    epoch loop with gradient accumulation, optional gradient clipping,
    optional validation and per-epoch finalisation, relying on helpers that
    ``BaseTrainer`` provides (device, grad scaler, loaders, logging, resume).
    """

    def __init__(self, cfg: dict[str, Any], args: Any | None = None) -> None:
        """Store the configuration and declare the components ``build()`` fills.

        Args:
            cfg: Experiment configuration. This class reads its ``dataset``,
                ``model``, ``train`` and ``optimizer`` sections and the
                optional ``scheduler``, ``loss`` and ``augmentation`` ones.
            args: Optional extra arguments, forwarded as-is to ``BaseTrainer``.
        """
        super().__init__(cfg=cfg, args=args)
        # Everything below stays ``None`` until ``build()`` has run.
        self.model: XNet2d | None = None
        self.optimizer = None
        self.scheduler = None
        self.loader: DataLoader | None = None
        self.val_loader: DataLoader | None = None
        self.seg_loss = None

    def build(self) -> None:
        """Instantiate model, optimizer, scheduler, loss and data loaders.

        Also resumes from a checkpoint and initialises experiment logging via
        the ``BaseTrainer`` helpers ``_maybe_resume`` and ``_init_swanlab``.

        Raises:
            KeyError: If a required configuration entry (``dataset``,
                ``model``, ``train``, ``optimizer`` or
                ``dataset.num_classes``) is missing.
        """
        dataset_cfg = self.cfg["dataset"]
        model_cfg = self.cfg["model"]
        train_cfg = self.cfg["train"]

        self.model = XNet2d(
            # The model section wins; the dataset section is the fallback.
            in_channels=int(model_cfg.get("in_channels", dataset_cfg.get("in_channels", 3))),
            num_classes=int(dataset_cfg["num_classes"]),
            encoder_channels=tuple(model_cfg.get("encoder_channels", (32, 64, 128, 192))),
            encoder_depths=tuple(model_cfg.get("encoder_depths", (2, 2, 2, 2))),
            decoder_channels=tuple(model_cfg.get("decoder_channels", (128, 64, 32))),
            stem_channels=int(model_cfg.get("stem_channels", 24)),
            bottleneck_depth=int(model_cfg.get("bottleneck_depth", 1)),
            global_ratio=float(model_cfg.get("global_ratio", 2.0)),
            wavelet_type=str(model_cfg.get("wavelet_type", "haar")),
            wavelet_level=int(model_cfg.get("wavelet_level", 1)),
            use_wavelet_branch=bool(model_cfg.get("use_wavelet_branch", True)),
            use_global_branch_stage1=bool(model_cfg.get("use_global_branch_stage1", False)),
            ssm_d_state=int(model_cfg.get("ssm_d_state", 16)),
            ssm_forward_type=str(model_cfg.get("ssm_forward_type", "v3")),
            ssm_backend=str(model_cfg.get("ssm_backend", "auto")),
            use_frequency_refine=bool(model_cfg.get("use_frequency_refine", True)),
            low_freq_radius_h=float(model_cfg.get("low_freq_radius_h", 0.25)),
            low_freq_radius_w=float(model_cfg.get("low_freq_radius_w", 0.25)),
            learnable_low_freq_radius=bool(
                model_cfg.get("learnable_low_freq_radius", True)
            ),
            guide_mode=str(model_cfg.get("guide_mode", "affine")),
            out_channels=model_cfg.get("out_channels"),
        ).to(self.device)

        self.optimizer = build_optimizer(self.model, self.cfg["optimizer"])
        self.scheduler = build_scheduler(self.optimizer, self.cfg.get("scheduler"))

        # Without a ``loss`` section ``seg_loss`` stays ``None`` and
        # ``_compute_losses`` falls back to BCE-with-logits.
        loss_cfg = self.cfg.get("loss")
        if loss_cfg is not None:
            self.seg_loss = build_loss(loss_cfg)

        # ``or {}`` also covers an ``augmentation`` key that is present but
        # null, which would otherwise raise AttributeError on ``.get``.
        augmentation_cfg = self.cfg.get("augmentation") or {}
        self.loader = self._build_segmentation_loader(
            split=str(dataset_cfg.get("split", "train")),
            split_file=dataset_cfg.get("split_file"),
            batch_size=self._resolve_batch_size("batch_size", 4),
            shuffle=bool(train_cfg.get("shuffle", True)),
            augmentation_config=augmentation_cfg.get("train"),
        )
        self.val_loader = self._build_val_loader(
            # Validation defaults to the configured training batch size.
            batch_size=self._resolve_batch_size(
                "val_batch_size",
                int(train_cfg.get("batch_size", 4)),
            ),
            shuffle=False,
        )

        self._maybe_resume(
            module_map={"model": self.model},
            optimizer=self.optimizer,
            scheduler=self.scheduler,
        )
        self._init_swanlab()

    def _compute_losses(
        self,
        image: torch.Tensor,
        mask: torch.Tensor,
    ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
        """Run the model on ``image`` and compute the segmentation loss.

        Args:
            image: Input batch, already on ``self.device``.
            mask: Ground-truth masks, already on ``self.device``.

        Returns:
            ``(outputs, losses)``: the raw model output mapping (must contain
            ``"seg_logits"``) and a dict with the keys ``"total"`` and
            ``"seg"``, which currently hold the same tensor.

        Raises:
            RuntimeError: If ``build()`` has not created the model yet.
        """
        if self.model is None:
            raise RuntimeError("Model is not initialized.")

        with torch.autocast(device_type=self.device.type, enabled=self._amp_enabled()):
            outputs = self.model(image)
            seg_logits = outputs["seg_logits"]
            if self.seg_loss is None:
                # BCE-with-logits needs a floating-point target; integer or
                # bool masks are cast, floating masks are passed unchanged.
                target = mask if mask.is_floating_point() else mask.to(seg_logits.dtype)
                seg_loss = torch.nn.functional.binary_cross_entropy_with_logits(
                    seg_logits, target
                )
            else:
                seg_loss = self.seg_loss(seg_logits, mask)

        # Only one loss term exists, so the total is the segmentation loss.
        losses = {
            "total": seg_loss,
            "seg": seg_loss,
        }
        return outputs, losses

    @staticmethod
    def _detach_metrics(losses: dict[str, torch.Tensor]) -> dict[str, float]:
        """Convert a dict of single-element loss tensors to Python floats."""
        return {key: float(value.detach().cpu()) for key, value in losses.items()}

    def _validate(self) -> dict[str, float] | None:
        """Evaluate the model on the validation loader.

        Returns:
            Mean ``"total"`` and ``"seg"`` losses over the validation batches,
            merged with whatever ``_compute_validation_metric_values`` yields,
            or ``None`` when there is no model, no validation loader, or the
            loader produced no batches.
        """
        if self.model is None or self.val_loader is None:
            return None

        # ``train()`` switches back to training mode at the start of each epoch.
        self.model.eval()
        metrics = self._build_validation_metrics()
        loss_sums = {"total": 0.0, "seg": 0.0}
        steps = 0

        with torch.no_grad():
            for batch in self.val_loader:
                image = batch["image"].to(self.device)
                mask = batch["mask"].to(self.device)
                outputs, losses = self._compute_losses(image, mask)
                step_losses = self._detach_metrics(losses)
                loss_sums["total"] += step_losses["total"]
                loss_sums["seg"] += step_losses["seg"]
                self._update_validation_metrics(
                    metrics,
                    logits=outputs["seg_logits"],
                    target=mask,
                )
                steps += 1

        if steps == 0:
            return None

        val_metrics = {
            "total": loss_sums["total"] / steps,
            "seg": loss_sums["seg"] / steps,
        }
        val_metrics.update(self._compute_validation_metric_values(metrics))
        return val_metrics

    def train(self) -> None:
        """Run the training loop from ``self.start_epoch`` to ``train.epochs``.

        Each epoch iterates the training loader with gradient accumulation
        over ``self._accum_steps()`` batches, steps the scheduler once,
        optionally validates, and hands the averaged metrics plus the
        checkpoint state to ``_finalize_epoch``. Loggers are closed on exit,
        including when an exception propagates.

        Raises:
            RuntimeError: If ``build()`` has not been called, or if the
                training loader yields no batches.
        """
        if self.model is None or self.loader is None or self.optimizer is None:
            raise RuntimeError("Trainer.build() must be called before train().")

        epochs = int(self.cfg["train"].get("epochs", 1))
        accum_steps = self._accum_steps()

        try:
            self._print_training_setup(
                model_map={"model": self.model},
                loader_map={"train": self.loader, "val": self.val_loader},
                optimizer=self.optimizer,
                scheduler=self.scheduler,
            )

            for epoch in range(self.start_epoch, epochs):
                self.model.train()
                self.optimizer.zero_grad()

                train_metric_sums = {
                    "total": 0.0,
                    "seg": 0.0,
                }
                # The gradient norm only exists on optimizer-step iterations,
                # so it is tracked apart from the per-batch loss sums.
                grad_norm_sum = 0.0
                grad_norm_count = 0
                train_metrics: dict[str, float] | None = None
                end_time = time.perf_counter()
                num_steps = len(self.loader)

                for step, batch in enumerate(self.loader, start=1):
                    data_time = time.perf_counter() - end_time
                    iter_start = time.perf_counter()

                    image = batch["image"].to(self.device)
                    mask = batch["mask"].to(self.device)
                    _, losses = self._compute_losses(image, mask)

                    # Scale so the accumulated gradient is the window average.
                    scaled_total_loss = losses["total"] / accum_steps
                    self.grad_scaler.scale(scaled_total_loss).backward()

                    grad_norm = None
                    # Step on every full window and flush the final batch.
                    should_step = (step % accum_steps == 0) or (step == num_steps)
                    if should_step:
                        if self._grad_clip_enabled():
                            # Clipping must see unscaled gradients.
                            self.grad_scaler.unscale_(self.optimizer)
                            grad_norm = self._clip_gradients(self.model)
                        self.grad_scaler.step(self.optimizer)
                        self.grad_scaler.update()
                        self.optimizer.zero_grad()

                    train_metrics = self._detach_metrics(losses)
                    for key, value in train_metrics.items():
                        train_metric_sums.setdefault(key, 0.0)
                        train_metric_sums[key] += value

                    if grad_norm is not None:
                        # Still reported in the per-step log on this iteration.
                        train_metrics["grad_norm"] = grad_norm
                        grad_norm_sum += grad_norm
                        grad_norm_count += 1

                    iter_time = time.perf_counter() - iter_start
                    self._maybe_log_step(
                        epoch=epoch,
                        step=step,
                        num_steps=num_steps,
                        data_time=data_time,
                        iter_time=iter_time,
                        metrics=train_metrics,
                        prefix="train",
                    )
                    end_time = time.perf_counter()

                # Fail before touching the scheduler when nothing was trained.
                if train_metrics is None:
                    raise RuntimeError("Training loader is empty.")

                if self.scheduler is not None:
                    self.scheduler.step()

                train_metrics = self._average_metric_sums(train_metric_sums, num_steps)
                if grad_norm_count:
                    # Average over the iterations that produced a norm, not
                    # over all batches; dividing by ``num_steps`` understates
                    # the value whenever gradients are accumulated.
                    train_metrics["grad_norm"] = grad_norm_sum / grad_norm_count

                val_metrics = self._validate() if self._should_validate(epoch) else None
                summary, should_stop = self._finalize_epoch(
                    epoch=epoch,
                    train_metrics=train_metrics,
                    val_metrics=val_metrics,
                    checkpoint_state={
                        "model": self.model.state_dict(),
                        "optimizer": self.optimizer.state_dict(),
                        "scheduler": self.scheduler.state_dict() if self.scheduler is not None else None,
                    },
                )
                print(summary)

                if should_stop:
                    print({"epoch": epoch, "message": "early stopping triggered"})
                    break
        finally:
            self._close_loggers()
|