from __future__ import annotations

from abc import ABC, abstractmethod
from pathlib import Path
import pprint
import random
import time
from typing import Any

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.amp import GradScaler

from lib.data import build_dataloader
from lib.tools import build_metrics, compute_metrics, reset_metrics, update_metrics

try:
    import swanlab
except ImportError:
    swanlab = None


class BaseTrainer(ABC):
    """
    Base class for trainers.

    Design goals:
    - a single configuration entry point (``cfg``)
    - unified creation of models / optimizers / schedulers
    - concrete training workflows override only the minimal set of methods
    """

    def __init__(self, cfg: dict[str, Any], args: Any | None = None) -> None:
        """Store the config, seed RNGs, pick the device and prepare output state.

        Args:
            cfg: Nested configuration dict (sections such as ``train``,
                ``dataset``, ``checkpoint``, ``validation``, ``logging``).
            args: Optional extra arguments (e.g. parsed CLI args); stored as-is.
        """
        self.cfg = cfg
        self.args = args
        self._set_random_seed()
        self.device = self._build_device()
        self.output_dir = self._build_output_dir()
        self.start_epoch = 0
        self.best_metric: float | None = None
        self.no_improve_epochs = 0
        self.swanlab_run = None
        # The scaler is a no-op unless train.amp is set and the device is CUDA.
        self.grad_scaler = GradScaler("cuda", enabled=self._amp_enabled())

    # ------------------------------------------------------------------ setup

    def _set_random_seed(self) -> None:
        """Seed Python, NumPy and torch RNGs from ``train.seed`` (default 42).

        Also sets cuDNN flags: ``train.deterministic`` enables deterministic
        kernels and disables benchmarking; otherwise benchmarking is enabled.
        """
        seed = int(self.cfg.get("train", {}).get("seed", 42))
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
        deterministic = bool(self.cfg.get("train", {}).get("deterministic", False))
        if deterministic:
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
        else:
            torch.backends.cudnn.deterministic = False
            torch.backends.cudnn.benchmark = True

    def _build_device(self) -> torch.device:
        """Return the device from ``train.device`` (default ``"cpu"``).

        Any CUDA spec (``"cuda"``, ``"cuda:1"``, ...) falls back to ``"cpu"``
        when CUDA is not available.
        """
        device_name = self.cfg.get("train", {}).get("device", "cpu")
        if (
            isinstance(device_name, str)
            and device_name.startswith("cuda")
            and not torch.cuda.is_available()
        ):
            device_name = "cpu"
        return torch.device(device_name)

    def _build_output_dir(self) -> Path:
        """Create (if needed) and return ``checkpoint.dir``."""
        output_dir = self.cfg.get("checkpoint", {}).get("dir", "outputs/supervised_segmentation")
        path = Path(output_dir)
        path.mkdir(parents=True, exist_ok=True)
        return path

    def _amp_enabled(self) -> bool:
        """AMP is on only when ``train.amp`` is truthy and the device is CUDA."""
        return bool(self.cfg.get("train", {}).get("amp", False)) and self.device.type == "cuda"

    # -------------------------------------------------------- auto batch size

    def _auto_batch_size_cfg(self) -> dict[str, Any]:
        """Return ``train.auto_batch_size`` if it is a dict, else ``{}``."""
        cfg = self.cfg.get("train", {}).get("auto_batch_size", {})
        return cfg if isinstance(cfg, dict) else {}

    def _auto_batch_size_enabled(self) -> bool:
        return bool(self._auto_batch_size_cfg().get("enabled", False))

    def _gpu_total_memory_gb(self) -> float | None:
        """Total memory of the current CUDA device in GiB, or ``None`` off-GPU."""
        if self.device.type != "cuda" or not torch.cuda.is_available():
            return None
        props = torch.cuda.get_device_properties(self.device)
        return float(props.total_memory / (1024 ** 3))

    def _estimate_auto_batch_size(self, *, default_batch_size: int, ssl: bool = False) -> int:
        """Heuristically scale the batch size with the GPU's total memory.

        Returns ``default_batch_size`` unchanged when auto batch sizing is
        disabled or no CUDA device is in use. Otherwise ``reference_batch_size``
        is scaled by ``total_gb * target_fraction`` relative to
        ``reference_gpu_gb * 0.75 * memory_penalty``. The scaled value is
        floored at ``default_batch_size``, capped at ``max_batch_size`` and
        finally floored at ``min_batch_size``.

        Args:
            default_batch_size: The configured batch size.
            ssl: Selects a larger default ``memory_penalty`` (1.35 vs 1.0).
        """
        cfg = self._auto_batch_size_cfg()
        if not cfg.get("enabled", False):
            return int(default_batch_size)
        total_gb = self._gpu_total_memory_gb()
        if total_gb is None:
            return int(default_batch_size)
        target_fraction = float(cfg.get("target_memory_fraction", 0.75))
        target_fraction = min(max(target_fraction, 0.1), 0.95)
        reference_gpu_gb = float(cfg.get("reference_gpu_gb", 8.0))
        reference_batch_size = int(cfg.get("reference_batch_size", default_batch_size))
        max_batch_size = int(cfg.get("max_batch_size", reference_batch_size))
        min_batch_size = int(cfg.get("min_batch_size", 1))
        memory_penalty = float(cfg.get("memory_penalty", 1.0 if not ssl else 1.35))
        scaled = int(
            (reference_batch_size * total_gb * target_fraction)
            / max(reference_gpu_gb * 0.75 * memory_penalty, 1e-6)
        )
        batch_size = max(min_batch_size, min(max_batch_size, max(default_batch_size, scaled)))
        return int(batch_size)

    def _resolve_batch_size(self, key: str, default: int, *, ssl: bool = False) -> int:
        """Read ``train[key]`` (or ``default``) and apply auto batch sizing.

        Prints a notice when auto sizing changed the configured value.
        """
        train_cfg = self.cfg.get("train", {})
        configured = int(train_cfg.get(key, default))
        batch_size = self._estimate_auto_batch_size(default_batch_size=configured, ssl=ssl)
        if self._auto_batch_size_enabled() and batch_size != configured:
            print(
                {
                    "message": "auto_batch_size adjusted",
                    "key": key,
                    "configured": configured,
                    "resolved": batch_size,
                    "gpu_total_gb": self._gpu_total_memory_gb(),
                }
            )
        return batch_size

    # ---------------------------------------------------------------- dataset

    def _dataset_cfg(self) -> dict[str, Any]:
        return self.cfg.get("dataset", {})

    def _dataset_name(self) -> str:
        """``dataset.dataset_name`` (or ``dataset.name``); raises ValueError if missing."""
        dataset_cfg = self._dataset_cfg()
        dataset_name = dataset_cfg.get("dataset_name") or dataset_cfg.get("name")
        if not dataset_name:
            raise ValueError("dataset.dataset_name is required.")
        return str(dataset_name)

    def _dataset_root(self) -> str:
        """``dataset.root``; raises ValueError if missing."""
        dataset_cfg = self._dataset_cfg()
        root = dataset_cfg.get("root")
        if not root:
            raise ValueError("dataset.root is required.")
        return str(root)

    def _image_size(self) -> tuple[int, int]:
        """``dataset.image_size`` as ``(height, width)``; raises ValueError if missing."""
        dataset_cfg = self._dataset_cfg()
        image_size = dataset_cfg.get("image_size")
        if image_size is None:
            raise ValueError("dataset.image_size is required.")
        return int(image_size[0]), int(image_size[1])

    def _build_resize_transform(self, *, mode: str) -> Any:
        """Return a callable resizing a ``(C, H, W)`` or ``(H, W)`` tensor to image_size.

        ``mode="image"`` uses bilinear interpolation; any other mode (masks)
        uses nearest. The output keeps the input's rank.
        """
        height, width = self._image_size()
        interpolation_mode = "bilinear" if mode == "image" else "nearest"

        def _transform(tensor: torch.Tensor) -> torch.Tensor:
            # F.interpolate needs (N, C, H, W): add a batch dim, plus a channel
            # dim for single-channel (H, W) inputs such as label masks.
            is_2d = tensor.dim() == 2
            batch = tensor[None, None] if is_2d else tensor.unsqueeze(0)
            source_dtype = batch.dtype
            # Nearest is a pure index gather, so a float32 round trip is exact
            # for integer labels (< 2**24) and avoids kernels lacking int support.
            cast_back = interpolation_mode == "nearest" and not (
                batch.is_floating_point() or batch.is_complex()
            )
            if cast_back:
                batch = batch.float()
            resized = F.interpolate(
                batch,
                size=(height, width),
                mode=interpolation_mode,
                align_corners=False if interpolation_mode != "nearest" else None,
            )
            if cast_back:
                resized = resized.to(source_dtype)
            return resized[0, 0] if is_2d else resized.squeeze(0)

        return _transform

    def _build_segmentation_loader(
        self,
        *,
        split: str,
        batch_size: int,
        shuffle: bool,
        split_file: str | None = None,
        augmentation_config: dict[str, Any] | None = None,
    ):
        """Build a data loader for ``split`` via ``lib.data.build_dataloader``.

        Worker-related options (``persistent_workers``, ``prefetch_factor``)
        are only forwarded when ``train.num_workers > 0``. ``pin_memory``
        defaults to True on CUDA.
        """
        train_cfg = self.cfg.get("train", {})
        num_workers = max(0, int(train_cfg.get("num_workers", 0)))
        persistent_workers = bool(train_cfg.get("persistent_workers", False)) if num_workers > 0 else False
        loader = build_dataloader(
            dataset_name=self._dataset_name(),
            root=self._dataset_root(),
            split=split,
            split_file=split_file,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            augmentation_config=augmentation_config,
            image_transform=self._build_resize_transform(mode="image"),
            mask_transform=self._build_resize_transform(mode="mask"),
            pin_memory=bool(train_cfg.get("pin_memory", self.device.type == "cuda")),
            persistent_workers=persistent_workers,
            prefetch_factor=train_cfg.get("prefetch_factor") if num_workers > 0 else None,
        )
        return loader

    def _build_val_loader(
        self,
        *,
        batch_size: int,
        shuffle: bool = False,
    ):
        """Build the validation loader, or return ``None`` if ``dataset.val_split`` is None."""
        dataset_cfg = self._dataset_cfg()
        val_split = dataset_cfg.get("val_split", "val")
        if val_split is None:
            return None
        return self._build_segmentation_loader(
            split=str(val_split),
            split_file=dataset_cfg.get("val_split_file"),
            batch_size=batch_size,
            shuffle=shuffle,
            augmentation_config=self.cfg.get("augmentation", {}).get("val"),
        )

    # ------------------------------------------------------------- checkpoint

    def _checkpoint_cfg(self) -> dict[str, Any]:
        return self.cfg.get("checkpoint", {})

    def _logging_cfg(self) -> dict[str, Any]:
        return self.cfg.get("logging", {})

    def _validation_cfg(self) -> dict[str, Any]:
        return self.cfg.get("validation", {})

    def _checkpoint_enabled(self) -> bool:
        return bool(self._checkpoint_cfg().get("save", True))

    def _best_mode(self) -> str:
        """``checkpoint.monitor_mode`` normalised to lower case (default ``"min"``)."""
        return str(self._checkpoint_cfg().get("monitor_mode", "min")).strip().lower()

    def _is_better_metric(self, metric: float) -> bool:
        """True if ``metric`` beats ``best_metric``; any mode except "max" means lower is better."""
        if self.best_metric is None:
            return True
        if self._best_mode() == "max":
            return metric > self.best_metric
        return metric < self.best_metric

    def _save_checkpoint(self, filename: str, state: dict[str, Any]) -> Path | None:
        """Save ``state`` to ``output_dir / filename``; return the path, or None if disabled.

        The state is written to a temporary file first and then moved into
        place, so an interrupted save never leaves a truncated checkpoint.
        """
        if not self._checkpoint_enabled():
            return None
        path = self.output_dir / filename
        tmp_path = path.with_name(path.name + ".tmp")
        torch.save(state, tmp_path)
        tmp_path.replace(path)
        return path

    def _resume_checkpoint_path(self) -> Path | None:
        """``checkpoint.resume`` as a path (relative paths resolved against CWD), or None."""
        resume_path = self._checkpoint_cfg().get("resume")
        if not resume_path:
            return None
        path = Path(str(resume_path))
        if not path.is_absolute():
            path = Path.cwd() / path
        return path

    def _maybe_resume(
        self,
        *,
        module_map: dict[str, Any],
        optimizer: Any | None = None,
        scheduler: Any | None = None,
    ) -> dict[str, Any] | None:
        """Restore training state from ``checkpoint.resume`` if configured.

        Args:
            module_map: Checkpoint key -> module; ``None`` modules are skipped.
            optimizer: Optional optimizer whose state is restored.
            scheduler: Optional LR scheduler whose state is restored.

        Returns:
            The loaded checkpoint dict, or ``None`` when no resume is configured.

        Raises:
            FileNotFoundError: If the configured checkpoint does not exist.
        """
        path = self._resume_checkpoint_path()
        if path is None:
            return None
        if not path.exists():
            raise FileNotFoundError(f"Resume checkpoint not found: {path}")
        # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True;
        # checkpoints containing custom objects would need weights_only=False
        # (only for trusted files).
        checkpoint = torch.load(path, map_location="cpu")
        strict = bool(self._checkpoint_cfg().get("resume_strict", True))
        for key, module in module_map.items():
            if module is None:
                continue
            state_dict = checkpoint.get(key)
            if state_dict is not None:
                module.load_state_dict(state_dict, strict=strict)
        if optimizer is not None and checkpoint.get("optimizer") is not None:
            optimizer.load_state_dict(checkpoint["optimizer"])
        if scheduler is not None and checkpoint.get("scheduler") is not None:
            scheduler.load_state_dict(checkpoint["scheduler"])
        # A disabled GradScaler saves an empty state dict; loading that into an
        # enabled scaler raises, so only non-empty scaler state is restored.
        if checkpoint.get("grad_scaler"):
            self.grad_scaler.load_state_dict(checkpoint["grad_scaler"])
        if checkpoint.get("best_metric") is not None:
            self.best_metric = float(checkpoint["best_metric"])
        elif checkpoint.get("metrics") is not None:
            # Fall back to the monitored value recorded with the checkpoint.
            monitor_name = str(self._checkpoint_cfg().get("monitor", "total"))
            monitor_value = checkpoint["metrics"].get(f"val_{monitor_name}")
            if monitor_value is None:
                monitor_value = checkpoint["metrics"].get(monitor_name)
            if monitor_value is not None:
                self.best_metric = float(monitor_value)
        if checkpoint.get("no_improve_epochs") is not None:
            self.no_improve_epochs = int(checkpoint["no_improve_epochs"])
        if bool(self._checkpoint_cfg().get("resume_training", True)):
            self.start_epoch = int(checkpoint.get("epoch", -1)) + 1
        return checkpoint

    # ------------------------------------------------------------- validation

    def _validation_enabled(self) -> bool:
        return bool(self._validation_cfg().get("enabled", True))

    def _validation_interval(self) -> int:
        return max(1, int(self._validation_cfg().get("interval", 1)))

    def _should_validate(self, epoch: int) -> bool:
        """Validate on every ``interval``-th epoch (``epoch`` is 0-based)."""
        return self._validation_enabled() and ((epoch + 1) % self._validation_interval() == 0)

    def _metric_task_mode(self) -> str:
        """``task_mode`` from the metrics config (validation.metrics or top-level metrics)."""
        validation_cfg = self._validation_cfg()
        metrics_cfg = validation_cfg.get("metrics", self.cfg.get("metrics"))
        if isinstance(metrics_cfg, dict):
            return str(metrics_cfg.get("task_mode", "binary"))
        return "binary"

    def _metric_threshold(self) -> float:
        validation_cfg = self._validation_cfg()
        threshold = validation_cfg.get("threshold", 0.5)
        return float(threshold)

    def _build_validation_metrics(self) -> dict[str, Any]:
        """Build metric objects via ``lib.tools.build_metrics``; ``{}`` if unconfigured."""
        validation_cfg = self._validation_cfg()
        metrics_cfg = validation_cfg.get("metrics", self.cfg.get("metrics"))
        if metrics_cfg is None:
            return {}
        return build_metrics(metrics_cfg)

    def _early_stopping_enabled(self) -> bool:
        return bool(self._validation_cfg().get("early_stopping", False))

    def _early_stopping_patience(self) -> int:
        return max(1, int(self._validation_cfg().get("early_stopping_patience", 10)))

    def _early_stopping_min_delta(self) -> float:
        return float(self._validation_cfg().get("early_stopping_min_delta", 0.0))

    def _update_validation_metrics(
        self,
        metrics: dict[str, Any],
        *,
        logits: torch.Tensor,
        target: torch.Tensor,
    ) -> None:
        """Accumulate one batch into ``metrics`` (no-op when ``metrics`` is empty)."""
        if not metrics:
            return
        update_metrics(
            metrics,
            logits,
            target,
            task_mode=self._metric_task_mode(),
            threshold=self._metric_threshold(),
            num_classes=int(self._dataset_cfg().get("num_classes", 1)),
        )

    def _compute_validation_metric_values(self, metrics: dict[str, Any]) -> dict[str, float]:
        """Compute accumulated metric values, then reset the metric objects."""
        if not metrics:
            return {}
        values = compute_metrics(metrics)
        reset_metrics(metrics)
        return values

    # ---------------------------------------------------------------- logging

    def _init_swanlab(self) -> None:
        """Start a SwanLab run when ``logging.use_swanlab`` is set and SwanLab is installed."""
        logging_cfg = self._logging_cfg()
        if not bool(logging_cfg.get("use_swanlab", False)):
            return
        if swanlab is None:
            print("SwanLab is not installed. Logging will continue without SwanLab.")
            return
        run_name = logging_cfg.get("experiment_name") or self.output_dir.name
        self.swanlab_run = swanlab.init(
            project=logging_cfg.get("project", "X_SSL_Net"),
            name=run_name,
            config=self.cfg,
            logdir=logging_cfg.get("swanlab_logdir", "swanlog"),
            mode=logging_cfg.get("swanlab_mode"),
        )

    def _log_metrics(self, metrics: dict[str, float], *, step: int) -> None:
        """Send ``metrics`` to SwanLab at ``step``; no-op without an active run."""
        if self.swanlab_run is None:
            return
        swanlab.log(metrics, step=step)

    def _close_loggers(self) -> None:
        if self.swanlab_run is not None:
            swanlab.finish()
            self.swanlab_run = None

    def _log_interval(self) -> int:
        return max(1, int(self._logging_cfg().get("log_interval", 20)))

    # -------------------------------------------------------------- gradients

    def _grad_clip_cfg(self) -> dict[str, Any]:
        """Return ``train.grad_clip`` if it is a dict, else ``{}``."""
        cfg = self.cfg.get("train", {}).get("grad_clip", {})
        return cfg if isinstance(cfg, dict) else {}

    def _grad_clip_enabled(self) -> bool:
        return bool(self._grad_clip_cfg().get("enabled", False))

    def _accum_steps(self) -> int:
        return max(1, int(self.cfg.get("train", {}).get("accum_steps", 1)))

    def _clip_gradients(self, module: nn.Module | None) -> float | None:
        """Clip gradients of ``module`` and return the pre-clip total norm.

        Returns ``None`` when clipping is disabled, ``module`` is None, or no
        parameter has a gradient. NOTE(review): gradients are not unscaled
        here; with AMP, callers should unscale them first — confirm.
        """
        if module is None or not self._grad_clip_enabled():
            return None
        cfg = self._grad_clip_cfg()
        max_norm = float(cfg.get("max_norm", 1.0))
        norm_type = float(cfg.get("norm_type", 2.0))
        params = [param for param in module.parameters() if param.requires_grad and param.grad is not None]
        if not params:
            return None
        total_norm = torch.nn.utils.clip_grad_norm_(params, max_norm=max_norm, norm_type=norm_type)
        return float(total_norm.detach().cpu() if isinstance(total_norm, torch.Tensor) else total_norm)

    def _current_lrs(self, optimizer: Any | None) -> list[float]:
        """Learning rate of each optimizer param group (``[]`` without an optimizer)."""
        if optimizer is None:
            return []
        return [float(group.get("lr", 0.0)) for group in optimizer.param_groups]

    # --------------------------------------------------------------- summaries

    @staticmethod
    def _count_parameters(module: nn.Module | None) -> dict[str, int]:
        if module is None:
            return {"total": 0, "trainable": 0}
        total = sum(param.numel() for param in module.parameters())
        trainable = sum(param.numel() for param in module.parameters() if param.requires_grad)
        return {"total": int(total), "trainable": int(trainable)}

    @staticmethod
    def _safe_len(obj: Any) -> int | None:
        """``len(obj)``, or ``None`` for objects without a length (e.g. iterable datasets)."""
        try:
            return len(obj)
        except TypeError:
            return None

    @staticmethod
    def _loader_summary(loader: Any | None) -> dict[str, Any] | None:
        """Describe a data loader's size and worker settings (None for no loader)."""
        if loader is None:
            return None
        dataset = getattr(loader, "dataset", None)
        return {
            "dataset_size": BaseTrainer._safe_len(dataset) if dataset is not None else None,
            "num_batches": BaseTrainer._safe_len(loader),
            "batch_size": getattr(loader, "batch_size", None),
            "num_workers": getattr(loader, "num_workers", None),
            "pin_memory": getattr(loader, "pin_memory", None),
            "persistent_workers": getattr(loader, "persistent_workers", None),
            "prefetch_factor": getattr(loader, "prefetch_factor", None),
            "drop_last": getattr(loader, "drop_last", None),
        }

    def _training_setup_summary(
        self,
        *,
        model_map: dict[str, nn.Module | None],
        loader_map: dict[str, Any | None],
        optimizer: Any | None = None,
        scheduler: Any | None = None,
    ) -> dict[str, Any]:
        """Collect config sections, parameter counts, loader info and CUDA info.

        ``scheduler`` is accepted for interface symmetry; the summary reports
        the ``scheduler`` config section instead.
        """
        return {
            "trainer": self.cfg.get("trainer", {}).get("name"),
            "device": str(self.device),
            "amp_enabled": self._amp_enabled(),
            "output_dir": str(self.output_dir),
            "start_epoch": self.start_epoch,
            "train": self.cfg.get("train", {}),
            "dataset": self.cfg.get("dataset", {}),
            "model": self.cfg.get("model", {}),
            "optimizer": self.cfg.get("optimizer", {}),
            "scheduler": self.cfg.get("scheduler"),
            "current_lrs": self._current_lrs(optimizer),
            "validation": self.cfg.get("validation", {}),
            "checkpoint": self.cfg.get("checkpoint", {}),
            "logging": self.cfg.get("logging", {}),
            "model_parameters": {
                name: self._count_parameters(module) for name, module in model_map.items()
            },
            "loaders": {
                name: self._loader_summary(loader) for name, loader in loader_map.items()
            },
            "cuda": {
                "available": torch.cuda.is_available(),
                "device_name": torch.cuda.get_device_name(self.device) if self.device.type == "cuda" else None,
                "device_count": torch.cuda.device_count(),
            },
        }

    def _print_training_setup(
        self,
        *,
        model_map: dict[str, nn.Module | None],
        loader_map: dict[str, Any | None],
        optimizer: Any | None = None,
        scheduler: Any | None = None,
    ) -> None:
        """Pretty-print the setup summary unless ``logging.print_training_setup`` is false."""
        if not bool(self._logging_cfg().get("print_training_setup", True)):
            return
        summary = self._training_setup_summary(
            model_map=model_map,
            loader_map=loader_map,
            optimizer=optimizer,
            scheduler=scheduler,
        )
        print("========== TRAINING SETUP ==========")
        pprint.pprint(summary, sort_dicts=False, width=120)
        print("======== END TRAINING SETUP ========")

    def _gpu_memory_mb(self) -> float:
        """Peak allocated CUDA memory in MiB (0.0 off-GPU)."""
        if self.device.type != "cuda" or not torch.cuda.is_available():
            return 0.0
        return float(torch.cuda.max_memory_allocated(device=self.device) / (1024 ** 2))

    def _performance_snapshot(
        self,
        *,
        epoch: int,
        step: int,
        num_steps: int,
        data_time: float,
        iter_time: float,
        metrics: dict[str, float],
        prefix: str = "train",
    ) -> dict[str, float | int]:
        """Bundle timing, memory, first-group LR and prefixed metrics for printing.

        The LR is read from ``self.optimizer`` if a subclass defines it.
        """
        snapshot: dict[str, float | int] = {
            "epoch": epoch,
            "step": step,
            "num_steps": num_steps,
            "data_time": data_time,
            "iter_time": iter_time,
            "gpu_memory_mb": self._gpu_memory_mb(),
        }
        lrs = self._current_lrs(getattr(self, "optimizer", None))
        if lrs:
            snapshot["lr"] = lrs[0]
        for key, value in metrics.items():
            snapshot[f"{prefix}_{key}"] = value
        return snapshot

    def _maybe_log_step(
        self,
        *,
        epoch: int,
        step: int,
        num_steps: int,
        data_time: float,
        iter_time: float,
        metrics: dict[str, float],
        prefix: str = "train",
    ) -> None:
        """Print and log step metrics every ``log_interval`` steps and on the last step.

        Metrics are sent to SwanLab (when active) under ``"{prefix}/..."`` keys
        at the global step ``epoch * num_steps + step``, which increases
        monotonically across epochs for a fixed ``num_steps``.
        """
        if step % self._log_interval() != 0 and step != num_steps:
            return
        snapshot = self._performance_snapshot(
            epoch=epoch,
            step=step,
            num_steps=num_steps,
            data_time=data_time,
            iter_time=iter_time,
            metrics=metrics,
            prefix=prefix,
        )
        print(snapshot)
        log_metrics = {f"{prefix}/{key}": value for key, value in metrics.items()}
        log_metrics.update(
            {
                f"{prefix}/data_time": data_time,
                f"{prefix}/iter_time": iter_time,
                f"{prefix}/gpu_memory_mb": float(snapshot["gpu_memory_mb"]),
            }
        )
        if "lr" in snapshot:
            log_metrics[f"{prefix}/lr"] = float(snapshot["lr"])
        # Previously built but never sent; step-level metrics were lost.
        self._log_metrics(log_metrics, step=epoch * num_steps + step)

    @staticmethod
    def _average_metric_sums(metric_sums: dict[str, float], steps: int) -> dict[str, float]:
        """Divide each accumulated sum by ``steps`` (``{}`` when ``steps <= 0``)."""
        if steps <= 0:
            return {}
        return {key: value / steps for key, value in metric_sums.items()}

    # ------------------------------------------------------------ epoch end

    def _base_checkpoint_state(self, *, epoch: int, metrics: dict[str, float] | None = None) -> dict[str, Any]:
        """Trainer-level state common to every checkpoint."""
        state = {
            "epoch": epoch,
            "cfg": self.cfg,
            "metrics": metrics or {},
            "grad_scaler": self.grad_scaler.state_dict(),
            "no_improve_epochs": self.no_improve_epochs,
        }
        return state

    def _finalize_epoch(
        self,
        *,
        epoch: int,
        train_metrics: dict[str, float],
        val_metrics: dict[str, float] | None,
        checkpoint_state: dict[str, Any],
    ) -> tuple[dict[str, Any], bool]:
        """Track the best metric, save checkpoints, log, and decide on early stopping.

        Args:
            epoch: 0-based epoch index.
            train_metrics: Averaged training metrics for the epoch.
            val_metrics: Validation metrics, or ``None`` if not validated.
            checkpoint_state: Subclass state (models, optimizer, ...) merged
                under the trainer-level keys of ``_base_checkpoint_state``.

        Returns:
            ``(summary, should_stop)``.

        Raises:
            KeyError: If ``checkpoint.monitor`` is missing from ``val_metrics``.
        """
        merged_metrics = dict(train_metrics)
        if val_metrics is not None:
            merged_metrics.update({f"val_{key}": value for key, value in val_metrics.items()})

        improved = False
        if val_metrics is not None:
            monitor_name = str(self._checkpoint_cfg().get("monitor", "total"))
            if monitor_name not in val_metrics:
                raise KeyError(f"Checkpoint monitor '{monitor_name}' not found in val metrics.")
            monitor_value = float(val_metrics[monitor_name])
            delta = self._early_stopping_min_delta()
            previous_best = self.best_metric
            # An improvement must beat the previous best by more than min_delta;
            # every mode other than "max" is treated as lower-is-better.
            if previous_best is None:
                is_better = True
            elif self._best_mode() == "max":
                is_better = monitor_value > (previous_best + delta)
            else:
                is_better = monitor_value < (previous_best - delta)
            if is_better:
                self.best_metric = monitor_value
                self.no_improve_epochs = 0
                improved = True
                best_state = dict(checkpoint_state)
                best_state.update(
                    self._base_checkpoint_state(
                        epoch=epoch,
                        metrics=merged_metrics,
                    )
                )
                best_state["best_metric"] = self.best_metric
                self._save_checkpoint("best.pth", best_state)
            else:
                self.no_improve_epochs += 1

        save_last = bool(self._checkpoint_cfg().get("save_last", True))
        if save_last:
            last_state = dict(checkpoint_state)
            last_state.update(self._base_checkpoint_state(epoch=epoch, metrics=merged_metrics))
            if self.best_metric is not None:
                last_state["best_metric"] = self.best_metric
            self._save_checkpoint("last.pth", last_state)

        summary = {"epoch": epoch}
        summary.update(train_metrics)
        if val_metrics is not None:
            summary.update({f"val_{key}": value for key, value in val_metrics.items()})
        if self.best_metric is not None:
            summary["best_metric"] = float(self.best_metric)
        summary["no_improve_epochs"] = self.no_improve_epochs
        lrs = self._current_lrs(getattr(self, "optimizer", None))
        if lrs:
            summary["lr"] = lrs[0]
        self._log_metrics(summary, step=epoch)

        should_stop = False
        if val_metrics is not None and self._early_stopping_enabled():
            should_stop = self.no_improve_epochs >= self._early_stopping_patience()
        # Added after logging, so these flags are returned but not sent to SwanLab.
        summary["early_stop"] = should_stop
        summary["improved"] = improved
        return summary, should_stop

    @abstractmethod
    def build(self) -> None:
        """
        Create the objects needed to run: models, optimizers, data loaders, etc.
        """

    @abstractmethod
    def train(self) -> None:
        """
        Run the full training loop.
        """