"""Supervised 2D segmentation trainer built on top of ``BaseTrainer``."""

from __future__ import annotations

import time
from typing import Any

import torch
from torch.utils.data import DataLoader

from lib.modules import SegmentationModel2d
from lib.tools import build_loss, build_optimizer, build_scheduler

from .base import BaseTrainer


class SupervisedSegmentationTrainer(BaseTrainer):
    """Train a ``SegmentationModel2d`` with a single supervised segmentation loss.

    Call :meth:`build` to construct the model, optimizer, scheduler, loss and
    data loaders from ``cfg``, then :meth:`train` to run the epoch loop.
    """

    def __init__(self, cfg: dict[str, Any], args: Any | None = None) -> None:
        super().__init__(cfg=cfg, args=args)
        # Everything below is populated by ``build()``; ``train()`` refuses to
        # run while the model, loader or optimizer is still ``None``.
        self.model: SegmentationModel2d | None = None
        self.optimizer = None
        self.scheduler = None
        self.loader: DataLoader | None = None
        self.val_loader: DataLoader | None = None
        # ``None`` means "fall back to BCE-with-logits" in ``_compute_losses``.
        self.seg_loss = None

    def build(self) -> None:
        """Construct model, optimizer, scheduler, loss and data loaders from ``self.cfg``.

        Reads the required ``dataset``, ``model``, ``train`` and ``optimizer``
        sections and the optional ``scheduler``, ``loss`` and ``augmentation``
        sections. Afterwards it calls the base-class hooks ``_maybe_resume``
        (with the model, optimizer and scheduler) and ``_init_swanlab``.
        """
        dataset_cfg = self.cfg["dataset"]
        model_cfg = self.cfg["model"]
        train_cfg = self.cfg["train"]

        self.model = SegmentationModel2d(
            num_classes=dataset_cfg["num_classes"],
            model_name=model_cfg["model_name"],
            load_weights=model_cfg.get("load_weights", False),
            decoder_channels=model_cfg.get("decoder_channels"),
            use_multiscale_features=model_cfg.get("use_multiscale_features", True),
            include_patch_embed=model_cfg.get("include_patch_embed", True),
        ).to(self.device)
        self.optimizer = build_optimizer(self.model, self.cfg["optimizer"])
        self.scheduler = build_scheduler(self.optimizer, self.cfg.get("scheduler"))

        loss_cfg = self.cfg.get("loss")
        if loss_cfg is not None:
            self.seg_loss = build_loss(loss_cfg)

        # ``augmentation`` may be absent or explicitly null (e.g. an empty YAML
        # key); treat both as "no train-time augmentation config" instead of
        # crashing on ``None.get``.
        augmentation_cfg = self.cfg.get("augmentation") or {}
        self.loader = self._build_segmentation_loader(
            split=str(dataset_cfg.get("split", "train")),
            split_file=dataset_cfg.get("split_file"),
            batch_size=self._resolve_batch_size("batch_size", 4),
            shuffle=bool(train_cfg.get("shuffle", True)),
            augmentation_config=augmentation_cfg.get("train"),
        )
        # ``train.batch_size`` (default 4) is passed as the fallback value for
        # ``val_batch_size``.
        self.val_loader = self._build_val_loader(
            batch_size=self._resolve_batch_size(
                "val_batch_size",
                int(train_cfg.get("batch_size", 4)),
            ),
            shuffle=False,
        )
        self._maybe_resume(
            module_map={"model": self.model},
            optimizer=self.optimizer,
            scheduler=self.scheduler,
        )
        self._init_swanlab()

    def _compute_losses(
        self,
        image: torch.Tensor,
        mask: torch.Tensor,
    ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
        """Run the model on ``image`` and compute the segmentation loss against ``mask``.

        Args:
            image: Input batch (callers move it to ``self.device`` first).
            mask: Segmentation target passed unchanged to the loss function.

        Returns:
            ``(outputs, losses)``. ``outputs`` is the raw model output dict,
            which must contain ``"seg_logits"``. ``losses`` has keys
            ``"total"`` and ``"seg"``; they are currently the same tensor.

        Raises:
            RuntimeError: If ``build()`` has not created the model yet.
        """
        if self.model is None:
            raise RuntimeError("Model is not initialized.")
        # Forward pass and loss computation both run inside the autocast region;
        # with ``enabled=False`` autocast does nothing.
        with torch.autocast(device_type=self.device.type, enabled=self._amp_enabled()):
            outputs = self.model(image)
            seg_logits = outputs["seg_logits"]
            if self.seg_loss is None:
                # NOTE(review): BCE-with-logits expects a floating-point target
                # with the same shape as ``seg_logits``; confirm the dataset's
                # ``mask`` satisfies this when no ``loss`` config is given.
                seg_loss = torch.nn.functional.binary_cross_entropy_with_logits(
                    seg_logits, mask
                )
            else:
                seg_loss = self.seg_loss(seg_logits, mask)
            total_loss = seg_loss
        losses = {
            "total": total_loss,
            "seg": seg_loss,
        }
        return outputs, losses

    @staticmethod
    def _detach_metrics(losses: dict[str, torch.Tensor]) -> dict[str, float]:
        """Convert each scalar loss tensor to a Python float (detached, moved to CPU)."""
        return {key: float(value.detach().cpu()) for key, value in losses.items()}

    def _validate(self) -> dict[str, float] | None:
        """Evaluate on ``self.val_loader`` and return averaged losses plus metrics.

        Returns ``None`` when the model or validation loader is missing, or when
        the loader yields no batches. Otherwise it returns:

        - the per-batch mean of the ``"total"`` and ``"seg"`` losses;
        - merged with the values from ``_compute_validation_metric_values``.

        The model is left in eval mode; ``train()`` switches it back at the
        start of every epoch.
        """
        if self.model is None or self.val_loader is None:
            return None
        self.model.eval()
        metrics = self._build_validation_metrics()
        total = 0.0
        seg = 0.0
        steps = 0
        with torch.no_grad():
            for batch in self.val_loader:
                image = batch["image"].to(self.device)
                mask = batch["mask"].to(self.device)
                outputs, losses = self._compute_losses(image, mask)
                batch_losses = self._detach_metrics(losses)
                total += batch_losses["total"]
                seg += batch_losses["seg"]
                self._update_validation_metrics(
                    metrics,
                    logits=outputs["seg_logits"],
                    target=mask,
                )
                steps += 1
        if steps == 0:
            return None
        val_metrics = {
            "total": total / steps,
            "seg": seg / steps,
        }
        val_metrics.update(self._compute_validation_metric_values(metrics))
        return val_metrics

    def train(self) -> None:
        """Run the training loop from ``self.start_epoch`` up to ``train.epochs``.

        Each epoch does the following:

        1. Forward/backward every batch through the grad scaler, with gradient
           accumulation over ``_accum_steps()`` micro-batches.
        2. Optionally clip gradients before each optimizer step.
        3. Step the scheduler once.
        4. Optionally validate.
        5. Pass the epoch-averaged metrics and a checkpoint state dict to
           ``_finalize_epoch``.

        Loggers are closed on exit, including when an exception is raised.

        Raises:
            RuntimeError: If ``build()`` has not been called, or if the training
                loader yields no batches.
        """
        if self.model is None or self.loader is None or self.optimizer is None:
            raise RuntimeError("Trainer.build() must be called before train().")
        epochs = int(self.cfg["train"].get("epochs", 1))
        accum_steps = self._accum_steps()
        try:
            self._print_training_setup(
                model_map={"model": self.model},
                loader_map={"train": self.loader, "val": self.val_loader},
                optimizer=self.optimizer,
                scheduler=self.scheduler,
            )
            for epoch in range(self.start_epoch, epochs):
                self.model.train()
                self.optimizer.zero_grad()
                train_metric_sums = {
                    "total": 0.0,
                    "seg": 0.0,
                }
                # grad_norm is only produced on optimizer-step iterations (and
                # only when clipping is enabled), so it gets its own denominator
                # instead of being averaged over every micro-batch.
                num_grad_norms = 0
                train_metrics: dict[str, float] | None = None
                end_time = time.perf_counter()
                num_steps = len(self.loader)
                for step, batch in enumerate(self.loader, start=1):
                    data_time = time.perf_counter() - end_time
                    iter_start = time.perf_counter()
                    image = batch["image"].to(self.device)
                    mask = batch["mask"].to(self.device)
                    _, losses = self._compute_losses(image, mask)
                    # Dividing by accum_steps makes the accumulated gradient an
                    # average over the accumulation window.
                    # NOTE(review): a trailing partial window (num_steps not a
                    # multiple of accum_steps) is still divided by accum_steps,
                    # so its update is proportionally smaller; confirm intended.
                    scaled_total_loss = losses["total"] / accum_steps
                    self.grad_scaler.scale(scaled_total_loss).backward()

                    grad_norm = None
                    # Step at the end of each accumulation window, and flush any
                    # leftover accumulated gradients on the epoch's last batch.
                    should_step = (step % accum_steps == 0) or (step == num_steps)
                    if should_step:
                        if self._grad_clip_enabled():
                            # Unscale first so clipping sees the true gradients.
                            self.grad_scaler.unscale_(self.optimizer)
                            grad_norm = self._clip_gradients(self.model)
                        self.grad_scaler.step(self.optimizer)
                        self.grad_scaler.update()
                        self.optimizer.zero_grad()

                    train_metrics = self._detach_metrics(losses)
                    if grad_norm is not None:
                        train_metrics["grad_norm"] = grad_norm
                        num_grad_norms += 1
                    for key, value in train_metrics.items():
                        train_metric_sums.setdefault(key, 0.0)
                        train_metric_sums[key] += value
                    iter_time = time.perf_counter() - iter_start
                    self._maybe_log_step(
                        epoch=epoch,
                        step=step,
                        num_steps=num_steps,
                        data_time=data_time,
                        iter_time=iter_time,
                        metrics=train_metrics,
                        prefix="train",
                    )
                    end_time = time.perf_counter()

                if self.scheduler is not None:
                    self.scheduler.step()
                if train_metrics is None:
                    raise RuntimeError("Training loader is empty.")
                train_metrics = self._average_metric_sums(train_metric_sums, num_steps)
                if num_grad_norms:
                    # Average grad_norm over the iterations that actually
                    # produced one, not over all num_steps micro-batches.
                    train_metrics = dict(train_metrics)
                    train_metrics["grad_norm"] = (
                        train_metric_sums["grad_norm"] / num_grad_norms
                    )
                val_metrics = self._validate() if self._should_validate(epoch) else None
                summary, should_stop = self._finalize_epoch(
                    epoch=epoch,
                    train_metrics=train_metrics,
                    val_metrics=val_metrics,
                    checkpoint_state={
                        "model": self.model.state_dict(),
                        "optimizer": self.optimizer.state_dict(),
                        "scheduler": (
                            self.scheduler.state_dict()
                            if self.scheduler is not None
                            else None
                        ),
                    },
                )
                print(summary)
                if should_stop:
                    print({"epoch": epoch, "message": "early stopping triggered"})
                    break
        finally:
            self._close_loggers()