from __future__ import annotations

import math
import pprint
import time
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.amp import GradScaler

from lib.data import build_dataloader
from lib.tools import build_metrics, compute_metrics, reset_metrics, update_metrics

try:
    import swanlab
except ImportError:
    # SwanLab is optional; every use below is guarded.
    swanlab = None


class BaseTrainer(ABC):
    """
    Base class for trainers.

    Design goals:
    - a single configuration entry point
    - unified creation of models / optimizers / schedulers
    - concrete training flows override only the minimum set of methods
    """

    def __init__(self, cfg: dict[str, Any], args: Any | None = None) -> None:
        """Store the configuration and initialise device, output dir and run state.

        Args:
            cfg: Nested configuration dictionary (sections such as ``train``,
                ``dataset``, ``checkpoint``, ``validation``, ``logging``).
            args: Optional extra arguments object; stored as-is.
        """
        self.cfg = cfg
        self.args = args
        self.device = self._build_device()
        self.output_dir = self._build_output_dir()
        self.start_epoch = 0
        self.best_metric: float | None = None
        self.no_improve_epochs = 0
        self.swanlab_run = None
        # The scaler is a no-op unless AMP is requested and the device is CUDA.
        self.grad_scaler = GradScaler("cuda", enabled=self._amp_enabled())

    # ------------------------------------------------------------------
    # Device / output directory / AMP
    # ------------------------------------------------------------------
    def _build_device(self) -> torch.device:
        """Return the device named by ``train.device`` (default ``"cpu"``).

        Any CUDA device name (``"cuda"``, ``"cuda:0"``, ...) falls back to
        ``"cpu"`` when CUDA is not available.
        """
        device_name = self.cfg.get("train", {}).get("device", "cpu")
        # Match indexed names such as "cuda:1" too, not only the bare "cuda".
        wants_cuda = str(device_name).startswith("cuda")
        if wants_cuda and not torch.cuda.is_available():
            device_name = "cpu"
        return torch.device(device_name)

    def _build_output_dir(self) -> Path:
        """Create (if needed) and return the directory from ``checkpoint.dir``."""
        output_dir = self.cfg.get("checkpoint", {}).get("dir", "outputs/supervised_segmentation")
        path = Path(output_dir)
        path.mkdir(parents=True, exist_ok=True)
        return path

    def _amp_enabled(self) -> bool:
        """Return True when ``train.amp`` is truthy and the device type is CUDA."""
        return bool(self.cfg.get("train", {}).get("amp", False)) and self.device.type == "cuda"

    # ------------------------------------------------------------------
    # Automatic batch size
    # ------------------------------------------------------------------
    def _auto_batch_size_cfg(self) -> dict[str, Any]:
        """Return ``train.auto_batch_size`` if it is a dict, otherwise ``{}``."""
        cfg = self.cfg.get("train", {}).get("auto_batch_size", {})
        return cfg if isinstance(cfg, dict) else {}

    def _auto_batch_size_enabled(self) -> bool:
        """Return whether ``train.auto_batch_size.enabled`` is truthy."""
        return bool(self._auto_batch_size_cfg().get("enabled", False))

    def _gpu_total_memory_gb(self) -> float | None:
        """Return the total memory of the CUDA device in GiB, or None off-GPU."""
        if self.device.type != "cuda" or not torch.cuda.is_available():
            return None
        props = torch.cuda.get_device_properties(self.device)
        return float(props.total_memory / (1024 ** 3))

    def _estimate_auto_batch_size(self, *, default_batch_size: int, ssl: bool = False) -> int:
        """Scale a batch size by the ratio of this GPU's memory to a reference GPU.

        Args:
            default_batch_size: Batch size to return when auto sizing is
                disabled or no CUDA device is in use; also the lower bound of
                the scaled value before clamping.
            ssl: Selects the default ``memory_penalty`` (1.35 when True,
                1.0 otherwise) if the config does not set one.

        Returns:
            The scaled batch size clamped to ``[min_batch_size, max_batch_size]``.
            ``max_batch_size`` defaults to ``reference_batch_size``, which in
            turn defaults to ``default_batch_size``.
        """
        cfg = self._auto_batch_size_cfg()
        if not cfg.get("enabled", False):
            return int(default_batch_size)

        total_gb = self._gpu_total_memory_gb()
        if total_gb is None:
            return int(default_batch_size)

        target_fraction = float(cfg.get("target_memory_fraction", 0.75))
        # Keep the requested fraction inside [0.1, 0.95].
        target_fraction = min(max(target_fraction, 0.1), 0.95)
        reference_gpu_gb = float(cfg.get("reference_gpu_gb", 8.0))
        reference_batch_size = int(cfg.get("reference_batch_size", default_batch_size))
        max_batch_size = int(cfg.get("max_batch_size", reference_batch_size))
        min_batch_size = int(cfg.get("min_batch_size", 1))
        memory_penalty = float(cfg.get("memory_penalty", 1.0 if not ssl else 1.35))

        # The denominator is floored at 1e-6 to avoid dividing by zero.
        scaled = int(
            (reference_batch_size * total_gb * target_fraction)
            / max(reference_gpu_gb * 0.75 * memory_penalty, 1e-6)
        )
        batch_size = max(min_batch_size, min(max_batch_size, max(default_batch_size, scaled)))
        return int(batch_size)

    def _resolve_batch_size(self, key: str, default: int, *, ssl: bool = False) -> int:
        """Read ``train[key]`` and apply automatic batch-size scaling to it.

        Prints a small report when auto sizing is enabled and changed the value.
        """
        train_cfg = self.cfg.get("train", {})
        configured = int(train_cfg.get(key, default))
        batch_size = self._estimate_auto_batch_size(default_batch_size=configured, ssl=ssl)
        if self._auto_batch_size_enabled() and batch_size != configured:
            print(
                {
                    "message": "auto_batch_size adjusted",
                    "key": key,
                    "configured": configured,
                    "resolved": batch_size,
                    "gpu_total_gb": self._gpu_total_memory_gb(),
                }
            )
        return batch_size

    # ------------------------------------------------------------------
    # Dataset configuration and loaders
    # ------------------------------------------------------------------
    def _dataset_cfg(self) -> dict[str, Any]:
        """Return the ``dataset`` config section (``{}`` if absent)."""
        return self.cfg.get("dataset", {})

    def _dataset_name(self) -> str:
        """Return ``dataset.dataset_name`` (or ``dataset.name``).

        Raises:
            ValueError: If neither key holds a truthy value.
        """
        dataset_cfg = self._dataset_cfg()
        dataset_name = dataset_cfg.get("dataset_name") or dataset_cfg.get("name")
        if not dataset_name:
            raise ValueError("dataset.dataset_name is required.")
        return str(dataset_name)

    def _dataset_root(self) -> str:
        """Return ``dataset.root`` as a string.

        Raises:
            ValueError: If the value is missing or falsy.
        """
        dataset_cfg = self._dataset_cfg()
        root = dataset_cfg.get("root")
        if not root:
            raise ValueError("dataset.root is required.")
        return str(root)

    def _image_size(self) -> tuple[int, int]:
        """Return the first two entries of ``dataset.image_size`` as ints.

        Raises:
            ValueError: If ``dataset.image_size`` is not set.
        """
        dataset_cfg = self._dataset_cfg()
        image_size = dataset_cfg.get("image_size")
        if image_size is None:
            raise ValueError("dataset.image_size is required.")
        return int(image_size[0]), int(image_size[1])

    def _build_resize_transform(self, *, mode: str) -> Any:
        """Return a callable that resizes a tensor to the configured image size.

        Args:
            mode: ``"image"`` selects bilinear interpolation; any other value
                selects nearest-neighbour interpolation.
        """
        height, width = self._image_size()
        interpolation_mode = "bilinear" if mode == "image" else "nearest"

        def _transform(tensor: torch.Tensor) -> torch.Tensor:
            # A leading batch dimension is added for F.interpolate and removed
            # again afterwards.
            # NOTE(review): presumably the input is (C, H, W) -- confirm with
            # the dataset implementation.
            resized = F.interpolate(
                tensor.unsqueeze(0),
                size=(height, width),
                mode=interpolation_mode,
                # align_corners is only passed for the non-nearest mode.
                align_corners=False if interpolation_mode != "nearest" else None,
            )
            return resized.squeeze(0)

        return _transform

    def _build_segmentation_loader(
        self,
        *,
        split: str,
        batch_size: int,
        shuffle: bool,
        split_file: str | None = None,
    ):
        """Build a data loader for ``split`` via ``build_dataloader``.

        Worker-related options are read from the ``train`` section;
        ``persistent_workers`` and ``prefetch_factor`` are only forwarded when
        ``num_workers`` is positive.
        """
        train_cfg = self.cfg.get("train", {})
        num_workers = max(0, int(train_cfg.get("num_workers", 0)))
        persistent_workers = bool(train_cfg.get("persistent_workers", False)) if num_workers > 0 else False
        loader = build_dataloader(
            dataset_name=self._dataset_name(),
            root=self._dataset_root(),
            split=split,
            split_file=split_file,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            image_transform=self._build_resize_transform(mode="image"),
            mask_transform=self._build_resize_transform(mode="mask"),
            # pin_memory defaults to True only on a CUDA device.
            pin_memory=bool(train_cfg.get("pin_memory", self.device.type == "cuda")),
            persistent_workers=persistent_workers,
            prefetch_factor=train_cfg.get("prefetch_factor") if num_workers > 0 else None,
        )
        return loader

    def _build_val_loader(
        self,
        *,
        batch_size: int,
        shuffle: bool = False,
    ):
        """Build the validation loader, or return None if ``dataset.val_split`` is None."""
        dataset_cfg = self._dataset_cfg()
        val_split = dataset_cfg.get("val_split", "val")
        if val_split is None:
            return None
        return self._build_segmentation_loader(
            split=str(val_split),
            split_file=dataset_cfg.get("val_split_file"),
            batch_size=batch_size,
            shuffle=shuffle,
        )

    # ------------------------------------------------------------------
    # Config section accessors
    # ------------------------------------------------------------------
    def _checkpoint_cfg(self) -> dict[str, Any]:
        """Return the ``checkpoint`` config section (``{}`` if absent)."""
        return self.cfg.get("checkpoint", {})

    def _logging_cfg(self) -> dict[str, Any]:
        """Return the ``logging`` config section (``{}`` if absent)."""
        return self.cfg.get("logging", {})

    def _validation_cfg(self) -> dict[str, Any]:
        """Return the ``validation`` config section (``{}`` if absent)."""
        return self.cfg.get("validation", {})

    # ------------------------------------------------------------------
    # Checkpointing
    # ------------------------------------------------------------------
    def _checkpoint_enabled(self) -> bool:
        """Return whether ``checkpoint.save`` is truthy (default True)."""
        return bool(self._checkpoint_cfg().get("save", True))

    def _best_mode(self) -> str:
        """Return ``checkpoint.monitor_mode`` as a string (default ``"min"``)."""
        return str(self._checkpoint_cfg().get("monitor_mode", "min"))

    def _is_better_metric(self, metric: float) -> bool:
        """Compare ``metric`` against ``self.best_metric``.

        Returns True when no best value exists yet. In ``"max"`` mode a larger
        value is better; in every other mode a smaller value is better.
        """
        if self.best_metric is None:
            return True
        if self._best_mode() == "max":
            return metric > self.best_metric
        return metric < self.best_metric

    def _save_checkpoint(self, filename: str, state: dict[str, Any]) -> Path | None:
        """Save ``state`` to ``output_dir / filename``.

        Returns:
            The written path, or None when checkpoint saving is disabled.
        """
        if not self._checkpoint_enabled():
            return None
        path = self.output_dir / filename
        torch.save(state, path)
        return path

    def _resume_checkpoint_path(self) -> Path | None:
        """Return ``checkpoint.resume`` as an absolute path, or None if unset.

        Relative paths are resolved against the current working directory.
        """
        resume_path = self._checkpoint_cfg().get("resume")
        if not resume_path:
            return None
        path = Path(str(resume_path))
        if not path.is_absolute():
            path = Path.cwd() / path
        return path

    def _maybe_resume(
        self,
        *,
        module_map: dict[str, Any],
        optimizer: Any | None = None,
        scheduler: Any | None = None,
    ) -> dict[str, Any] | None:
        """Restore state from the resume checkpoint, if one is configured.

        Args:
            module_map: Maps checkpoint keys to objects exposing
                ``load_state_dict(state, strict=...)``; None entries are skipped.
            optimizer: Optional optimizer restored from ``checkpoint["optimizer"]``.
            scheduler: Optional scheduler restored from ``checkpoint["scheduler"]``.

        Returns:
            The loaded checkpoint dict, or None when no resume path is set.

        Raises:
            FileNotFoundError: If the configured checkpoint file does not exist.
        """
        path = self._resume_checkpoint_path()
        if path is None:
            return None
        if not path.exists():
            raise FileNotFoundError(f"Resume checkpoint not found: {path}")

        # NOTE(review): torch.load unpickles the file; only resume from
        # checkpoints that come from a trusted source.
        checkpoint = torch.load(path, map_location="cpu")
        strict = bool(self._checkpoint_cfg().get("resume_strict", True))

        for key, module in module_map.items():
            if module is None:
                continue
            state_dict = checkpoint.get(key)
            if state_dict is not None:
                module.load_state_dict(state_dict, strict=strict)

        if optimizer is not None and checkpoint.get("optimizer") is not None:
            optimizer.load_state_dict(checkpoint["optimizer"])
        if scheduler is not None and checkpoint.get("scheduler") is not None:
            scheduler.load_state_dict(checkpoint["scheduler"])
        if checkpoint.get("grad_scaler") is not None:
            self.grad_scaler.load_state_dict(checkpoint["grad_scaler"])

        if checkpoint.get("best_metric") is not None:
            self.best_metric = float(checkpoint["best_metric"])
        elif checkpoint.get("metrics") is not None:
            # No explicit best value stored: fall back to the monitored metric,
            # trying the "val_"-prefixed key first and the bare name second.
            monitor_name = str(self._checkpoint_cfg().get("monitor", "total"))
            monitor_value = checkpoint["metrics"].get(f"val_{monitor_name}")
            if monitor_value is None:
                monitor_value = checkpoint["metrics"].get(monitor_name)
            if monitor_value is not None:
                self.best_metric = float(monitor_value)

        if checkpoint.get("no_improve_epochs") is not None:
            self.no_improve_epochs = int(checkpoint["no_improve_epochs"])
        if bool(self._checkpoint_cfg().get("resume_training", True)):
            # Continue with the epoch after the one stored in the checkpoint.
            self.start_epoch = int(checkpoint.get("epoch", -1)) + 1
        return checkpoint

    # ------------------------------------------------------------------
    # Validation and early stopping
    # ------------------------------------------------------------------
    def _validation_enabled(self) -> bool:
        """Return whether ``validation.enabled`` is truthy (default True)."""
        return bool(self._validation_cfg().get("enabled", True))

    def _validation_interval(self) -> int:
        """Return ``validation.interval`` as an int, at least 1."""
        return max(1, int(self._validation_cfg().get("interval", 1)))

    def _should_validate(self, epoch: int) -> bool:
        """Return True when validation is enabled and ``epoch + 1`` hits the interval."""
        return self._validation_enabled() and ((epoch + 1) % self._validation_interval() == 0)

    def _metric_task_mode(self) -> str:
        """Return ``task_mode`` from the metrics config, defaulting to ``"binary"``.

        The metrics config is ``validation.metrics`` or, if absent, the
        top-level ``metrics`` entry; a non-dict config yields ``"binary"``.
        """
        validation_cfg = self._validation_cfg()
        metrics_cfg = validation_cfg.get("metrics", self.cfg.get("metrics"))
        if isinstance(metrics_cfg, dict):
            return str(metrics_cfg.get("task_mode", "binary"))
        return "binary"

    def _metric_threshold(self) -> float:
        """Return ``validation.threshold`` as a float (default 0.5)."""
        validation_cfg = self._validation_cfg()
        threshold = validation_cfg.get("threshold", 0.5)
        return float(threshold)

    def _build_validation_metrics(self) -> dict[str, Any]:
        """Build metric objects via ``build_metrics``; ``{}`` when none are configured."""
        validation_cfg = self._validation_cfg()
        metrics_cfg = validation_cfg.get("metrics", self.cfg.get("metrics"))
        if metrics_cfg is None:
            return {}
        return build_metrics(metrics_cfg)

    def _early_stopping_enabled(self) -> bool:
        """Return whether ``validation.early_stopping`` is truthy (default False)."""
        return bool(self._validation_cfg().get("early_stopping", False))

    def _early_stopping_patience(self) -> int:
        """Return ``validation.early_stopping_patience`` as an int, at least 1."""
        return max(1, int(self._validation_cfg().get("early_stopping_patience", 10)))

    def _early_stopping_min_delta(self) -> float:
        """Return ``validation.early_stopping_min_delta`` as a float (default 0.0)."""
        return float(self._validation_cfg().get("early_stopping_min_delta", 0.0))

    def _update_validation_metrics(
        self,
        metrics: dict[str, Any],
        *,
        logits: torch.Tensor,
        target: torch.Tensor,
    ) -> None:
        """Feed one batch into ``update_metrics``; does nothing if ``metrics`` is empty."""
        if not metrics:
            return
        update_metrics(
            metrics,
            logits,
            target,
            task_mode=self._metric_task_mode(),
            threshold=self._metric_threshold(),
            num_classes=int(self._dataset_cfg().get("num_classes", 1)),
        )

    def _compute_validation_metric_values(self, metrics: dict[str, Any]) -> dict[str, float]:
        """Compute the metric values, reset the metric objects, and return the values."""
        if not metrics:
            return {}
        values = compute_metrics(metrics)
        reset_metrics(metrics)
        return values

    # ------------------------------------------------------------------
    # Logging (SwanLab)
    # ------------------------------------------------------------------
    def _init_swanlab(self) -> None:
        """Start a SwanLab run when ``logging.use_swanlab`` is truthy.

        If the package is not installed a notice is printed and logging
        continues without SwanLab.
        """
        logging_cfg = self._logging_cfg()
        if not bool(logging_cfg.get("use_swanlab", False)):
            return
        if swanlab is None:
            print("SwanLab is not installed. Logging will continue without SwanLab.")
            return

        run_name = logging_cfg.get("experiment_name") or self.output_dir.name
        self.swanlab_run = swanlab.init(
            project=logging_cfg.get("project", "X_SSL_Net"),
            name=run_name,
            config=self.cfg,
            mode=logging_cfg.get("swanlab_mode"),
        )

    def _log_metrics(self, metrics: dict[str, float], *, step: int) -> None:
        """Send ``metrics`` to SwanLab at ``step``; no-op when no run is active."""
        if self.swanlab_run is None:
            return
        swanlab.log(metrics, step=step)

    def _close_loggers(self) -> None:
        """Finish the active SwanLab run, if any, and clear the handle."""
        if self.swanlab_run is not None:
            swanlab.finish()
            self.swanlab_run = None

    def _log_interval(self) -> int:
        """Return ``logging.log_interval`` as an int, at least 1 (default 20)."""
        return max(1, int(self._logging_cfg().get("log_interval", 20)))

    # ------------------------------------------------------------------
    # Gradient clipping and optimizer helpers
    # ------------------------------------------------------------------
    def _grad_clip_cfg(self) -> dict[str, Any]:
        """Return ``train.grad_clip`` if it is a dict, otherwise ``{}``."""
        cfg = self.cfg.get("train", {}).get("grad_clip", {})
        return cfg if isinstance(cfg, dict) else {}

    def _grad_clip_enabled(self) -> bool:
        """Return whether ``train.grad_clip.enabled`` is truthy."""
        return bool(self._grad_clip_cfg().get("enabled", False))

    def _clip_gradients(self, module: nn.Module | None) -> float | None:
        """Clip the gradient norm of ``module``'s trainable parameters in place.

        Returns:
            The norm reported by ``clip_grad_norm_`` as a float, or None when
            clipping is disabled, ``module`` is None, or no parameter has a
            gradient.
        """
        if module is None or not self._grad_clip_enabled():
            return None
        cfg = self._grad_clip_cfg()
        max_norm = float(cfg.get("max_norm", 1.0))
        norm_type = float(cfg.get("norm_type", 2.0))
        params = [
            param
            for param in module.parameters()
            if param.requires_grad and param.grad is not None
        ]
        if not params:
            return None
        # NOTE(review): this method does not unscale gradients itself; with AMP
        # the caller presumably unscales via the grad scaler first -- confirm.
        total_norm = torch.nn.utils.clip_grad_norm_(params, max_norm=max_norm, norm_type=norm_type)
        return float(total_norm.detach().cpu() if isinstance(total_norm, torch.Tensor) else total_norm)

    def _current_lrs(self, optimizer: Any | None) -> list[float]:
        """Return the ``lr`` of every param group (``[]`` when ``optimizer`` is None)."""
        if optimizer is None:
            return []
        return [float(group.get("lr", 0.0)) for group in optimizer.param_groups]

    # ------------------------------------------------------------------
    # Setup summary
    # ------------------------------------------------------------------
    @staticmethod
    def _count_parameters(module: nn.Module | None) -> dict[str, int]:
        """Return total and trainable parameter counts (zeros for None)."""
        if module is None:
            return {"total": 0, "trainable": 0}
        total = sum(param.numel() for param in module.parameters())
        trainable = sum(param.numel() for param in module.parameters() if param.requires_grad)
        return {"total": int(total), "trainable": int(trainable)}

    @staticmethod
    def _loader_summary(loader: Any | None) -> dict[str, Any] | None:
        """Return a dict describing ``loader``, or None when ``loader`` is None.

        Lengths that cannot be determined (objects without ``__len__``, such
        as loaders over iterable-style datasets) are reported as None rather
        than raising, so this diagnostic never aborts training.
        """
        if loader is None:
            return None

        def _len_or_none(obj: Any) -> int | None:
            try:
                return len(obj)
            except TypeError:
                return None

        dataset = getattr(loader, "dataset", None)
        return {
            "dataset_size": _len_or_none(dataset) if dataset is not None else None,
            "num_batches": _len_or_none(loader),
            "batch_size": getattr(loader, "batch_size", None),
            "num_workers": getattr(loader, "num_workers", None),
            "pin_memory": getattr(loader, "pin_memory", None),
            "persistent_workers": getattr(loader, "persistent_workers", None),
            "prefetch_factor": getattr(loader, "prefetch_factor", None),
            "drop_last": getattr(loader, "drop_last", None),
        }

    def _training_setup_summary(
        self,
        *,
        model_map: dict[str, nn.Module | None],
        loader_map: dict[str, Any | None],
        optimizer: Any | None = None,
        scheduler: Any | None = None,
    ) -> dict[str, Any]:
        """Collect configuration, model, loader and CUDA details into one dict.

        Args:
            model_map: Name -> module (or None) used for parameter counts.
            loader_map: Name -> loader (or None) used for loader summaries.
            optimizer: Optional optimizer whose learning rates are reported.
            scheduler: Accepted for signature symmetry; not read here.
        """
        return {
            "trainer": self.cfg.get("trainer", {}).get("name"),
            "device": str(self.device),
            "amp_enabled": self._amp_enabled(),
            "output_dir": str(self.output_dir),
            "start_epoch": self.start_epoch,
            "train": self.cfg.get("train", {}),
            "dataset": self.cfg.get("dataset", {}),
            "model": self.cfg.get("model", {}),
            "optimizer": self.cfg.get("optimizer", {}),
            "scheduler": self.cfg.get("scheduler"),
            "current_lrs": self._current_lrs(optimizer),
            "validation": self.cfg.get("validation", {}),
            "checkpoint": self.cfg.get("checkpoint", {}),
            "logging": self.cfg.get("logging", {}),
            "model_parameters": {
                name: self._count_parameters(module)
                for name, module in model_map.items()
            },
            "loaders": {
                name: self._loader_summary(loader)
                for name, loader in loader_map.items()
            },
            "cuda": {
                "available": torch.cuda.is_available(),
                "device_name": torch.cuda.get_device_name(self.device) if self.device.type == "cuda" else None,
                "device_count": torch.cuda.device_count(),
            },
        }

    def _print_training_setup(
        self,
        *,
        model_map: dict[str, nn.Module | None],
        loader_map: dict[str, Any | None],
        optimizer: Any | None = None,
        scheduler: Any | None = None,
    ) -> None:
        """Pretty-print the setup summary unless ``logging.print_training_setup`` is falsy."""
        if not bool(self._logging_cfg().get("print_training_setup", True)):
            return
        summary = self._training_setup_summary(
            model_map=model_map,
            loader_map=loader_map,
            optimizer=optimizer,
            scheduler=scheduler,
        )
        print("========== TRAINING SETUP ==========")
        pprint.pprint(summary, sort_dicts=False, width=120)
        print("======== END TRAINING SETUP ========")

    # ------------------------------------------------------------------
    # Per-step performance logging
    # ------------------------------------------------------------------
    def _gpu_memory_mb(self) -> float:
        """Return ``torch.cuda.max_memory_allocated`` for the device in MiB (0.0 off-GPU)."""
        if self.device.type != "cuda" or not torch.cuda.is_available():
            return 0.0
        return float(torch.cuda.max_memory_allocated(device=self.device) / (1024 ** 2))

    def _performance_snapshot(
        self,
        *,
        epoch: int,
        step: int,
        num_steps: int,
        data_time: float,
        iter_time: float,
        metrics: dict[str, float],
        prefix: str = "train",
    ) -> dict[str, float | int]:
        """Build a flat dict of timing, memory, learning rate and prefixed metrics.

        The learning rate is taken from the first param group of
        ``self.optimizer`` when that attribute exists.
        """
        snapshot: dict[str, float | int] = {
            "epoch": epoch,
            "step": step,
            "num_steps": num_steps,
            "data_time": data_time,
            "iter_time": iter_time,
            "gpu_memory_mb": self._gpu_memory_mb(),
        }
        lrs = self._current_lrs(getattr(self, "optimizer", None))
        if lrs:
            snapshot["lr"] = lrs[0]
        for key, value in metrics.items():
            snapshot[f"{prefix}_{key}"] = value
        return snapshot

    def _maybe_log_step(
        self,
        *,
        epoch: int,
        step: int,
        num_steps: int,
        data_time: float,
        iter_time: float,
        metrics: dict[str, float],
        prefix: str = "train",
    ) -> None:
        """Print and log a snapshot on interval steps and on ``step == num_steps``."""
        if step % self._log_interval() != 0 and step != num_steps:
            return
        snapshot = self._performance_snapshot(
            epoch=epoch,
            step=step,
            num_steps=num_steps,
            data_time=data_time,
            iter_time=iter_time,
            metrics=metrics,
            prefix=prefix,
        )
        print(snapshot)

        log_metrics = {f"{prefix}/{key}": value for key, value in metrics.items()}
        log_metrics.update(
            {
                f"{prefix}/data_time": data_time,
                f"{prefix}/iter_time": iter_time,
                f"{prefix}/gpu_memory_mb": float(snapshot["gpu_memory_mb"]),
            }
        )
        if "lr" in snapshot:
            log_metrics[f"{prefix}/lr"] = float(snapshot["lr"])
        # Global step: epochs are laid end to end, num_steps steps each.
        self._log_metrics(log_metrics, step=epoch * max(1, num_steps) + step)

    @staticmethod
    def _average_metric_sums(metric_sums: dict[str, float], steps: int) -> dict[str, float]:
        """Divide every summed metric by ``steps``; ``{}`` when ``steps <= 0``."""
        if steps <= 0:
            return {}
        return {key: value / steps for key, value in metric_sums.items()}

    # ------------------------------------------------------------------
    # End-of-epoch bookkeeping
    # ------------------------------------------------------------------
    def _base_checkpoint_state(self, *, epoch: int, metrics: dict[str, float] | None = None) -> dict[str, Any]:
        """Return the checkpoint fields shared by every saved checkpoint."""
        state = {
            "epoch": epoch,
            "cfg": self.cfg,
            "metrics": metrics or {},
            "grad_scaler": self.grad_scaler.state_dict(),
            "no_improve_epochs": self.no_improve_epochs,
        }
        return state

    def _finalize_epoch(
        self,
        *,
        epoch: int,
        train_metrics: dict[str, float],
        val_metrics: dict[str, float] | None,
        checkpoint_state: dict[str, Any],
    ) -> tuple[dict[str, Any], bool]:
        """Update best-metric tracking, save checkpoints, log, and decide on early stop.

        Args:
            epoch: Index of the epoch that just finished.
            train_metrics: Training metrics for the epoch.
            val_metrics: Validation metrics, or None if validation was skipped
                (in which case best-metric tracking is left untouched).
            checkpoint_state: Trainer-specific state merged into each checkpoint.

        Returns:
            ``(summary, should_stop)`` where ``summary`` holds the epoch's
            metrics plus bookkeeping fields and ``should_stop`` is True when
            early stopping is enabled and patience is exhausted.

        Raises:
            KeyError: If the monitored metric is missing from ``val_metrics``.
        """
        merged_metrics = dict(train_metrics)
        if val_metrics is not None:
            merged_metrics.update({f"val_{key}": value for key, value in val_metrics.items()})

        improved = False
        if val_metrics is not None:
            monitor_name = str(self._checkpoint_cfg().get("monitor", "total"))
            if monitor_name not in val_metrics:
                raise KeyError(f"Checkpoint monitor '{monitor_name}' not found in val metrics.")
            monitor_value = float(val_metrics[monitor_name])
            delta = self._early_stopping_min_delta()
            previous_best = self.best_metric

            is_better = self._is_better_metric(monitor_value)
            # Once a best value exists, an improvement must exceed min_delta.
            if previous_best is not None and self._best_mode() == "max":
                is_better = monitor_value > (previous_best + delta)
            elif previous_best is not None and self._best_mode() == "min":
                is_better = monitor_value < (previous_best - delta)
            # A NaN metric must never become the best value: every later
            # comparison against NaN is False, so nothing could beat it.
            if math.isnan(monitor_value):
                is_better = False

            if is_better:
                self.best_metric = monitor_value
                self.no_improve_epochs = 0
                improved = True
                best_state = dict(checkpoint_state)
                best_state.update(
                    self._base_checkpoint_state(
                        epoch=epoch,
                        metrics=merged_metrics,
                    )
                )
                best_state["best_metric"] = self.best_metric
                self._save_checkpoint("best.pth", best_state)
            else:
                self.no_improve_epochs += 1

        save_last = bool(self._checkpoint_cfg().get("save_last", True))
        if save_last:
            last_state = dict(checkpoint_state)
            last_state.update(self._base_checkpoint_state(epoch=epoch, metrics=merged_metrics))
            if self.best_metric is not None:
                last_state["best_metric"] = self.best_metric
            self._save_checkpoint("last.pth", last_state)

        summary = {"epoch": epoch}
        summary.update(train_metrics)
        if val_metrics is not None:
            summary.update({f"val_{key}": value for key, value in val_metrics.items()})
        if self.best_metric is not None:
            summary["best_metric"] = float(self.best_metric)
        summary["no_improve_epochs"] = self.no_improve_epochs
        lrs = self._current_lrs(getattr(self, "optimizer", None))
        if lrs:
            summary["lr"] = lrs[0]
        self._log_metrics(summary, step=epoch)

        should_stop = False
        if val_metrics is not None and self._early_stopping_enabled():
            should_stop = self.no_improve_epochs >= self._early_stopping_patience()
        # These two flags are added after logging, so they are only returned.
        summary["early_stop"] = should_stop
        summary["improved"] = improved
        return summary, should_stop

    # ------------------------------------------------------------------
    # Abstract interface
    # ------------------------------------------------------------------
    @abstractmethod
    def build(self) -> None:
        """
        Create the model, optimizer, data loaders and other objects needed to run.
        """

    @abstractmethod
    def train(self) -> None:
        """
        Run the complete training procedure.
        """