| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662 |
- from __future__ import annotations
- from abc import ABC, abstractmethod
- from pathlib import Path
- import pprint
- import random
- import time
- from typing import Any
- import numpy as np
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from torch.amp import GradScaler
- from lib.data import build_dataloader
- from lib.tools import build_metrics, compute_metrics, reset_metrics, update_metrics
- try:
- import swanlab
- except ImportError:
- swanlab = None
- class BaseTrainer(ABC):
- """
- 训练器基类。
- 设计目标:
- - 统一配置入口
- - 统一模型/优化器/调度器创建
- - 不同训练流程只重写最少的方法
- """
- def __init__(self, cfg: dict[str, Any], args: Any | None = None) -> None:
- self.cfg = cfg
- self.args = args
- self._set_random_seed()
- self.device = self._build_device()
- self.output_dir = self._build_output_dir()
- self.start_epoch = 0
- self.best_metric: float | None = None
- self.no_improve_epochs = 0
- self.swanlab_run = None
- self.grad_scaler = GradScaler("cuda", enabled=self._amp_enabled())
- def _set_random_seed(self) -> None:
- seed = int(self.cfg.get("train", {}).get("seed", 42))
- random.seed(seed)
- np.random.seed(seed)
- torch.manual_seed(seed)
- if torch.cuda.is_available():
- torch.cuda.manual_seed(seed)
- torch.cuda.manual_seed_all(seed)
- deterministic = bool(self.cfg.get("train", {}).get("deterministic", False))
- if deterministic:
- torch.backends.cudnn.deterministic = True
- torch.backends.cudnn.benchmark = False
- else:
- torch.backends.cudnn.deterministic = False
- torch.backends.cudnn.benchmark = True
- def _build_device(self) -> torch.device:
- device_name = self.cfg.get("train", {}).get("device", "cpu")
- if device_name == "cuda" and not torch.cuda.is_available():
- device_name = "cpu"
- return torch.device(device_name)
- def _build_output_dir(self) -> Path:
- output_dir = self.cfg.get("checkpoint", {}).get("dir", "outputs/supervised_segmentation")
- path = Path(output_dir)
- path.mkdir(parents=True, exist_ok=True)
- return path
- def _amp_enabled(self) -> bool:
- return bool(self.cfg.get("train", {}).get("amp", False)) and self.device.type == "cuda"
- def _auto_batch_size_cfg(self) -> dict[str, Any]:
- cfg = self.cfg.get("train", {}).get("auto_batch_size", {})
- return cfg if isinstance(cfg, dict) else {}
- def _auto_batch_size_enabled(self) -> bool:
- return bool(self._auto_batch_size_cfg().get("enabled", False))
- def _gpu_total_memory_gb(self) -> float | None:
- if self.device.type != "cuda" or not torch.cuda.is_available():
- return None
- props = torch.cuda.get_device_properties(self.device)
- return float(props.total_memory / (1024 ** 3))
- def _estimate_auto_batch_size(self, *, default_batch_size: int, ssl: bool = False) -> int:
- cfg = self._auto_batch_size_cfg()
- if not cfg.get("enabled", False):
- return int(default_batch_size)
- total_gb = self._gpu_total_memory_gb()
- if total_gb is None:
- return int(default_batch_size)
- target_fraction = float(cfg.get("target_memory_fraction", 0.75))
- target_fraction = min(max(target_fraction, 0.1), 0.95)
- reference_gpu_gb = float(cfg.get("reference_gpu_gb", 8.0))
- reference_batch_size = int(cfg.get("reference_batch_size", default_batch_size))
- max_batch_size = int(cfg.get("max_batch_size", reference_batch_size))
- min_batch_size = int(cfg.get("min_batch_size", 1))
- memory_penalty = float(cfg.get("memory_penalty", 1.0 if not ssl else 1.35))
- scaled = int((reference_batch_size * total_gb * target_fraction) / max(reference_gpu_gb * 0.75 * memory_penalty, 1e-6))
- batch_size = max(min_batch_size, min(max_batch_size, max(default_batch_size, scaled)))
- return int(batch_size)
- def _resolve_batch_size(self, key: str, default: int, *, ssl: bool = False) -> int:
- train_cfg = self.cfg.get("train", {})
- configured = int(train_cfg.get(key, default))
- batch_size = self._estimate_auto_batch_size(default_batch_size=configured, ssl=ssl)
- if self._auto_batch_size_enabled() and batch_size != configured:
- print(
- {
- "message": "auto_batch_size adjusted",
- "key": key,
- "configured": configured,
- "resolved": batch_size,
- "gpu_total_gb": self._gpu_total_memory_gb(),
- }
- )
- return batch_size
- def _dataset_cfg(self) -> dict[str, Any]:
- return self.cfg.get("dataset", {})
- def _dataset_name(self) -> str:
- dataset_cfg = self._dataset_cfg()
- dataset_name = dataset_cfg.get("dataset_name") or dataset_cfg.get("name")
- if not dataset_name:
- raise ValueError("dataset.dataset_name is required.")
- return str(dataset_name)
- def _dataset_root(self) -> str:
- dataset_cfg = self._dataset_cfg()
- root = dataset_cfg.get("root")
- if not root:
- raise ValueError("dataset.root is required.")
- return str(root)
- def _image_size(self) -> tuple[int, int]:
- dataset_cfg = self._dataset_cfg()
- image_size = dataset_cfg.get("image_size")
- if image_size is None:
- raise ValueError("dataset.image_size is required.")
- return int(image_size[0]), int(image_size[1])
- def _build_resize_transform(self, *, mode: str) -> Any:
- height, width = self._image_size()
- interpolation_mode = "bilinear" if mode == "image" else "nearest"
- def _transform(tensor: torch.Tensor) -> torch.Tensor:
- resized = F.interpolate(
- tensor.unsqueeze(0),
- size=(height, width),
- mode=interpolation_mode,
- align_corners=False if interpolation_mode != "nearest" else None,
- )
- return resized.squeeze(0)
- return _transform
- def _build_segmentation_loader(
- self,
- *,
- split: str,
- batch_size: int,
- shuffle: bool,
- split_file: str | None = None,
- augmentation_config: dict[str, Any] | None = None,
- ):
- dataset_cfg = self._dataset_cfg()
- train_cfg = self.cfg.get("train", {})
- num_workers = max(0, int(train_cfg.get("num_workers", 0)))
- persistent_workers = bool(train_cfg.get("persistent_workers", False)) if num_workers > 0 else False
- loader = build_dataloader(
- dataset_name=self._dataset_name(),
- root=self._dataset_root(),
- split=split,
- split_file=split_file,
- batch_size=batch_size,
- shuffle=shuffle,
- num_workers=num_workers,
- augmentation_config=augmentation_config,
- image_transform=self._build_resize_transform(mode="image"),
- mask_transform=self._build_resize_transform(mode="mask"),
- pin_memory=bool(train_cfg.get("pin_memory", self.device.type == "cuda")),
- persistent_workers=persistent_workers,
- prefetch_factor=train_cfg.get("prefetch_factor") if num_workers > 0 else None,
- )
- return loader
- def _build_val_loader(
- self,
- *,
- batch_size: int,
- shuffle: bool = False,
- ):
- dataset_cfg = self._dataset_cfg()
- val_split = dataset_cfg.get("val_split", "val")
- if val_split is None:
- return None
- return self._build_segmentation_loader(
- split=str(val_split),
- split_file=dataset_cfg.get("val_split_file"),
- batch_size=batch_size,
- shuffle=shuffle,
- augmentation_config=self.cfg.get("augmentation", {}).get("val"),
- )
- def _checkpoint_cfg(self) -> dict[str, Any]:
- return self.cfg.get("checkpoint", {})
- def _logging_cfg(self) -> dict[str, Any]:
- return self.cfg.get("logging", {})
- def _validation_cfg(self) -> dict[str, Any]:
- return self.cfg.get("validation", {})
- def _checkpoint_enabled(self) -> bool:
- return bool(self._checkpoint_cfg().get("save", True))
- def _best_mode(self) -> str:
- return str(self._checkpoint_cfg().get("monitor_mode", "min"))
- def _is_better_metric(self, metric: float) -> bool:
- if self.best_metric is None:
- return True
- if self._best_mode() == "max":
- return metric > self.best_metric
- return metric < self.best_metric
- def _save_checkpoint(self, filename: str, state: dict[str, Any]) -> Path | None:
- if not self._checkpoint_enabled():
- return None
- path = self.output_dir / filename
- torch.save(state, path)
- return path
- def _resume_checkpoint_path(self) -> Path | None:
- resume_path = self._checkpoint_cfg().get("resume")
- if not resume_path:
- return None
- path = Path(str(resume_path))
- if not path.is_absolute():
- path = Path.cwd() / path
- return path
- def _maybe_resume(
- self,
- *,
- module_map: dict[str, Any],
- optimizer: Any | None = None,
- scheduler: Any | None = None,
- ) -> dict[str, Any] | None:
- path = self._resume_checkpoint_path()
- if path is None:
- return None
- if not path.exists():
- raise FileNotFoundError(f"Resume checkpoint not found: {path}")
- checkpoint = torch.load(path, map_location="cpu")
- strict = bool(self._checkpoint_cfg().get("resume_strict", True))
- for key, module in module_map.items():
- if module is None:
- continue
- state_dict = checkpoint.get(key)
- if state_dict is not None:
- module.load_state_dict(state_dict, strict=strict)
- if optimizer is not None and checkpoint.get("optimizer") is not None:
- optimizer.load_state_dict(checkpoint["optimizer"])
- if scheduler is not None and checkpoint.get("scheduler") is not None:
- scheduler.load_state_dict(checkpoint["scheduler"])
- if checkpoint.get("grad_scaler") is not None:
- self.grad_scaler.load_state_dict(checkpoint["grad_scaler"])
- if checkpoint.get("best_metric") is not None:
- self.best_metric = float(checkpoint["best_metric"])
- elif checkpoint.get("metrics") is not None:
- monitor_name = str(self._checkpoint_cfg().get("monitor", "total"))
- monitor_value = checkpoint["metrics"].get(f"val_{monitor_name}")
- if monitor_value is None:
- monitor_value = checkpoint["metrics"].get(monitor_name)
- if monitor_value is not None:
- self.best_metric = float(monitor_value)
- if checkpoint.get("no_improve_epochs") is not None:
- self.no_improve_epochs = int(checkpoint["no_improve_epochs"])
- if bool(self._checkpoint_cfg().get("resume_training", True)):
- self.start_epoch = int(checkpoint.get("epoch", -1)) + 1
- return checkpoint
- def _validation_enabled(self) -> bool:
- return bool(self._validation_cfg().get("enabled", True))
- def _validation_interval(self) -> int:
- return max(1, int(self._validation_cfg().get("interval", 1)))
- def _should_validate(self, epoch: int) -> bool:
- return self._validation_enabled() and ((epoch + 1) % self._validation_interval() == 0)
- def _metric_task_mode(self) -> str:
- validation_cfg = self._validation_cfg()
- metrics_cfg = validation_cfg.get("metrics", self.cfg.get("metrics"))
- if isinstance(metrics_cfg, dict):
- return str(metrics_cfg.get("task_mode", "binary"))
- return "binary"
- def _metric_threshold(self) -> float:
- validation_cfg = self._validation_cfg()
- threshold = validation_cfg.get("threshold", 0.5)
- return float(threshold)
- def _build_validation_metrics(self) -> dict[str, Any]:
- validation_cfg = self._validation_cfg()
- metrics_cfg = validation_cfg.get("metrics", self.cfg.get("metrics"))
- if metrics_cfg is None:
- return {}
- return build_metrics(metrics_cfg)
- def _early_stopping_enabled(self) -> bool:
- return bool(self._validation_cfg().get("early_stopping", False))
- def _early_stopping_patience(self) -> int:
- return max(1, int(self._validation_cfg().get("early_stopping_patience", 10)))
- def _early_stopping_min_delta(self) -> float:
- return float(self._validation_cfg().get("early_stopping_min_delta", 0.0))
- def _update_validation_metrics(
- self,
- metrics: dict[str, Any],
- *,
- logits: torch.Tensor,
- target: torch.Tensor,
- ) -> None:
- if not metrics:
- return
- update_metrics(
- metrics,
- logits,
- target,
- task_mode=self._metric_task_mode(),
- threshold=self._metric_threshold(),
- num_classes=int(self._dataset_cfg().get("num_classes", 1)),
- )
- def _compute_validation_metric_values(self, metrics: dict[str, Any]) -> dict[str, float]:
- if not metrics:
- return {}
- values = compute_metrics(metrics)
- reset_metrics(metrics)
- return values
- def _init_swanlab(self) -> None:
- logging_cfg = self._logging_cfg()
- if not bool(logging_cfg.get("use_swanlab", False)):
- return
- if swanlab is None:
- print("SwanLab is not installed. Logging will continue without SwanLab.")
- return
- run_name = logging_cfg.get("experiment_name") or self.output_dir.name
- self.swanlab_run = swanlab.init(
- project=logging_cfg.get("project", "X_SSL_Net"),
- name=run_name,
- config=self.cfg,
- logdir=logging_cfg.get("swanlab_logdir", "swanlog"),
- mode=logging_cfg.get("swanlab_mode"),
- )
- def _log_metrics(self, metrics: dict[str, float], *, step: int) -> None:
- if self.swanlab_run is None:
- return
- swanlab.log(metrics, step=step)
- def _close_loggers(self) -> None:
- if self.swanlab_run is not None:
- swanlab.finish()
- self.swanlab_run = None
- def _log_interval(self) -> int:
- return max(1, int(self._logging_cfg().get("log_interval", 20)))
- def _grad_clip_cfg(self) -> dict[str, Any]:
- cfg = self.cfg.get("train", {}).get("grad_clip", {})
- return cfg if isinstance(cfg, dict) else {}
- def _grad_clip_enabled(self) -> bool:
- return bool(self._grad_clip_cfg().get("enabled", False))
- def _accum_steps(self) -> int:
- return max(1, int(self.cfg.get("train", {}).get("accum_steps", 1)))
- def _clip_gradients(self, module: nn.Module | None) -> float | None:
- if module is None or not self._grad_clip_enabled():
- return None
- cfg = self._grad_clip_cfg()
- max_norm = float(cfg.get("max_norm", 1.0))
- norm_type = float(cfg.get("norm_type", 2.0))
- params = [param for param in module.parameters() if param.requires_grad and param.grad is not None]
- if not params:
- return None
- total_norm = torch.nn.utils.clip_grad_norm_(params, max_norm=max_norm, norm_type=norm_type)
- return float(total_norm.detach().cpu() if isinstance(total_norm, torch.Tensor) else total_norm)
- def _current_lrs(self, optimizer: Any | None) -> list[float]:
- if optimizer is None:
- return []
- return [float(group.get("lr", 0.0)) for group in optimizer.param_groups]
- @staticmethod
- def _count_parameters(module: nn.Module | None) -> dict[str, int]:
- if module is None:
- return {"total": 0, "trainable": 0}
- total = sum(param.numel() for param in module.parameters())
- trainable = sum(param.numel() for param in module.parameters() if param.requires_grad)
- return {"total": int(total), "trainable": int(trainable)}
- @staticmethod
- def _loader_summary(loader: Any | None) -> dict[str, Any] | None:
- if loader is None:
- return None
- dataset = getattr(loader, "dataset", None)
- return {
- "dataset_size": len(dataset) if dataset is not None else None,
- "num_batches": len(loader),
- "batch_size": getattr(loader, "batch_size", None),
- "num_workers": getattr(loader, "num_workers", None),
- "pin_memory": getattr(loader, "pin_memory", None),
- "persistent_workers": getattr(loader, "persistent_workers", None),
- "prefetch_factor": getattr(loader, "prefetch_factor", None),
- "drop_last": getattr(loader, "drop_last", None),
- }
- def _training_setup_summary(
- self,
- *,
- model_map: dict[str, nn.Module | None],
- loader_map: dict[str, Any | None],
- optimizer: Any | None = None,
- scheduler: Any | None = None,
- ) -> dict[str, Any]:
- return {
- "trainer": self.cfg.get("trainer", {}).get("name"),
- "device": str(self.device),
- "amp_enabled": self._amp_enabled(),
- "output_dir": str(self.output_dir),
- "start_epoch": self.start_epoch,
- "train": self.cfg.get("train", {}),
- "dataset": self.cfg.get("dataset", {}),
- "model": self.cfg.get("model", {}),
- "optimizer": self.cfg.get("optimizer", {}),
- "scheduler": self.cfg.get("scheduler"),
- "current_lrs": self._current_lrs(optimizer),
- "validation": self.cfg.get("validation", {}),
- "checkpoint": self.cfg.get("checkpoint", {}),
- "logging": self.cfg.get("logging", {}),
- "model_parameters": {
- name: self._count_parameters(module)
- for name, module in model_map.items()
- },
- "loaders": {
- name: self._loader_summary(loader)
- for name, loader in loader_map.items()
- },
- "cuda": {
- "available": torch.cuda.is_available(),
- "device_name": torch.cuda.get_device_name(self.device) if self.device.type == "cuda" else None,
- "device_count": torch.cuda.device_count(),
- },
- }
- def _print_training_setup(
- self,
- *,
- model_map: dict[str, nn.Module | None],
- loader_map: dict[str, Any | None],
- optimizer: Any | None = None,
- scheduler: Any | None = None,
- ) -> None:
- if not bool(self._logging_cfg().get("print_training_setup", True)):
- return
- summary = self._training_setup_summary(
- model_map=model_map,
- loader_map=loader_map,
- optimizer=optimizer,
- scheduler=scheduler,
- )
- print("========== TRAINING SETUP ==========")
- pprint.pprint(summary, sort_dicts=False, width=120)
- print("======== END TRAINING SETUP ========")
- def _gpu_memory_mb(self) -> float:
- if self.device.type != "cuda" or not torch.cuda.is_available():
- return 0.0
- return float(torch.cuda.max_memory_allocated(device=self.device) / (1024 ** 2))
- def _performance_snapshot(
- self,
- *,
- epoch: int,
- step: int,
- num_steps: int,
- data_time: float,
- iter_time: float,
- metrics: dict[str, float],
- prefix: str = "train",
- ) -> dict[str, float | int]:
- snapshot: dict[str, float | int] = {
- "epoch": epoch,
- "step": step,
- "num_steps": num_steps,
- "data_time": data_time,
- "iter_time": iter_time,
- "gpu_memory_mb": self._gpu_memory_mb(),
- }
- lrs = self._current_lrs(getattr(self, "optimizer", None))
- if lrs:
- snapshot["lr"] = lrs[0]
- for key, value in metrics.items():
- snapshot[f"{prefix}_{key}"] = value
- return snapshot
- def _maybe_log_step(
- self,
- *,
- epoch: int,
- step: int,
- num_steps: int,
- data_time: float,
- iter_time: float,
- metrics: dict[str, float],
- prefix: str = "train",
- ) -> None:
- if step % self._log_interval() != 0 and step != num_steps:
- return
- snapshot = self._performance_snapshot(
- epoch=epoch,
- step=step,
- num_steps=num_steps,
- data_time=data_time,
- iter_time=iter_time,
- metrics=metrics,
- prefix=prefix,
- )
- print(snapshot)
- log_metrics = {
- f"{prefix}/{key}": value
- for key, value in metrics.items()
- }
- log_metrics.update(
- {
- f"{prefix}/data_time": data_time,
- f"{prefix}/iter_time": iter_time,
- f"{prefix}/gpu_memory_mb": float(snapshot["gpu_memory_mb"]),
- }
- )
- if "lr" in snapshot:
- log_metrics[f"{prefix}/lr"] = float(snapshot["lr"])
- @staticmethod
- def _average_metric_sums(metric_sums: dict[str, float], steps: int) -> dict[str, float]:
- if steps <= 0:
- return {}
- return {key: value / steps for key, value in metric_sums.items()}
- def _base_checkpoint_state(self, *, epoch: int, metrics: dict[str, float] | None = None) -> dict[str, Any]:
- state = {
- "epoch": epoch,
- "cfg": self.cfg,
- "metrics": metrics or {},
- "grad_scaler": self.grad_scaler.state_dict(),
- "no_improve_epochs": self.no_improve_epochs,
- }
- return state
- def _finalize_epoch(
- self,
- *,
- epoch: int,
- train_metrics: dict[str, float],
- val_metrics: dict[str, float] | None,
- checkpoint_state: dict[str, Any],
- ) -> tuple[dict[str, Any], bool]:
- merged_metrics = dict(train_metrics)
- if val_metrics is not None:
- merged_metrics.update({f"val_{key}": value for key, value in val_metrics.items()})
- improved = False
- if val_metrics is not None:
- monitor_name = str(self._checkpoint_cfg().get("monitor", "total"))
- if monitor_name not in val_metrics:
- raise KeyError(f"Checkpoint monitor '{monitor_name}' not found in val metrics.")
- monitor_value = float(val_metrics[monitor_name])
- delta = self._early_stopping_min_delta()
- previous_best = self.best_metric
- is_better = self._is_better_metric(monitor_value)
- if previous_best is not None and self._best_mode() == "max":
- is_better = monitor_value > (previous_best + delta)
- elif previous_best is not None and self._best_mode() == "min":
- is_better = monitor_value < (previous_best - delta)
- if is_better:
- self.best_metric = monitor_value
- self.no_improve_epochs = 0
- improved = True
- best_state = dict(checkpoint_state)
- best_state.update(
- self._base_checkpoint_state(
- epoch=epoch,
- metrics=merged_metrics,
- )
- )
- best_state["best_metric"] = self.best_metric
- self._save_checkpoint("best.pth", best_state)
- else:
- self.no_improve_epochs += 1
- save_last = bool(self._checkpoint_cfg().get("save_last", True))
- if save_last:
- last_state = dict(checkpoint_state)
- last_state.update(self._base_checkpoint_state(epoch=epoch, metrics=merged_metrics))
- if self.best_metric is not None:
- last_state["best_metric"] = self.best_metric
- self._save_checkpoint("last.pth", last_state)
- summary = {"epoch": epoch}
- summary.update(train_metrics)
- if val_metrics is not None:
- summary.update({f"val_{key}": value for key, value in val_metrics.items()})
- if self.best_metric is not None:
- summary["best_metric"] = float(self.best_metric)
- summary["no_improve_epochs"] = self.no_improve_epochs
- lrs = self._current_lrs(getattr(self, "optimizer", None))
- if lrs:
- summary["lr"] = lrs[0]
- self._log_metrics(summary, step=epoch)
- should_stop = False
- if val_metrics is not None and self._early_stopping_enabled():
- should_stop = self.no_improve_epochs >= self._early_stopping_patience()
- summary["early_stop"] = should_stop
- summary["improved"] = improved
- return summary, should_stop
- @abstractmethod
- def build(self) -> None:
- """
- 创建模型、优化器、数据加载器等运行所需对象。
- """
- @abstractmethod
- def train(self) -> None:
- """
- 执行完整训练流程。
- """
|