|
@@ -0,0 +1,636 @@
|
|
|
|
|
+from __future__ import annotations
|
|
|
|
|
+
|
|
|
|
|
+from abc import ABC, abstractmethod
|
|
|
|
|
+from pathlib import Path
|
|
|
|
|
+import pprint
|
|
|
|
|
+import time
|
|
|
|
|
+from typing import Any
|
|
|
|
|
+
|
|
|
|
|
+import torch
|
|
|
|
|
+import torch.nn as nn
|
|
|
|
|
+import torch.nn.functional as F
|
|
|
|
|
+from torch.amp import GradScaler
|
|
|
|
|
+
|
|
|
|
|
+from lib.data import build_dataloader
|
|
|
|
|
+from lib.tools import build_metrics, compute_metrics, reset_metrics, update_metrics
|
|
|
|
|
+
|
|
|
|
|
+try:
|
|
|
|
|
+ import swanlab
|
|
|
|
|
+except ImportError:
|
|
|
|
|
+ swanlab = None
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class BaseTrainer(ABC):
|
|
|
|
|
+ """
|
|
|
|
|
+ 训练器基类。
|
|
|
|
|
+
|
|
|
|
|
+ 设计目标:
|
|
|
|
|
+ - 统一配置入口
|
|
|
|
|
+ - 统一模型/优化器/调度器创建
|
|
|
|
|
+ - 不同训练流程只重写最少的方法
|
|
|
|
|
+ """
|
|
|
|
|
+
|
|
|
|
|
+ def __init__(self, cfg: dict[str, Any], args: Any | None = None) -> None:
|
|
|
|
|
+ self.cfg = cfg
|
|
|
|
|
+ self.args = args
|
|
|
|
|
+ self.device = self._build_device()
|
|
|
|
|
+ self.output_dir = self._build_output_dir()
|
|
|
|
|
+ self.start_epoch = 0
|
|
|
|
|
+ self.best_metric: float | None = None
|
|
|
|
|
+ self.no_improve_epochs = 0
|
|
|
|
|
+ self.swanlab_run = None
|
|
|
|
|
+ self.grad_scaler = GradScaler("cuda", enabled=self._amp_enabled())
|
|
|
|
|
+
|
|
|
|
|
+ def _build_device(self) -> torch.device:
|
|
|
|
|
+ device_name = self.cfg.get("train", {}).get("device", "cpu")
|
|
|
|
|
+ if device_name == "cuda" and not torch.cuda.is_available():
|
|
|
|
|
+ device_name = "cpu"
|
|
|
|
|
+ return torch.device(device_name)
|
|
|
|
|
+
|
|
|
|
|
+ def _build_output_dir(self) -> Path:
|
|
|
|
|
+ output_dir = self.cfg.get("checkpoint", {}).get("dir", "outputs/supervised_segmentation")
|
|
|
|
|
+ path = Path(output_dir)
|
|
|
|
|
+ path.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
+ return path
|
|
|
|
|
+
|
|
|
|
|
+ def _amp_enabled(self) -> bool:
|
|
|
|
|
+ return bool(self.cfg.get("train", {}).get("amp", False)) and self.device.type == "cuda"
|
|
|
|
|
+
|
|
|
|
|
+ def _auto_batch_size_cfg(self) -> dict[str, Any]:
|
|
|
|
|
+ cfg = self.cfg.get("train", {}).get("auto_batch_size", {})
|
|
|
|
|
+ return cfg if isinstance(cfg, dict) else {}
|
|
|
|
|
+
|
|
|
|
|
+ def _auto_batch_size_enabled(self) -> bool:
|
|
|
|
|
+ return bool(self._auto_batch_size_cfg().get("enabled", False))
|
|
|
|
|
+
|
|
|
|
|
+ def _gpu_total_memory_gb(self) -> float | None:
|
|
|
|
|
+ if self.device.type != "cuda" or not torch.cuda.is_available():
|
|
|
|
|
+ return None
|
|
|
|
|
+ props = torch.cuda.get_device_properties(self.device)
|
|
|
|
|
+ return float(props.total_memory / (1024 ** 3))
|
|
|
|
|
+
|
|
|
|
|
+ def _estimate_auto_batch_size(self, *, default_batch_size: int, ssl: bool = False) -> int:
|
|
|
|
|
+ cfg = self._auto_batch_size_cfg()
|
|
|
|
|
+ if not cfg.get("enabled", False):
|
|
|
|
|
+ return int(default_batch_size)
|
|
|
|
|
+
|
|
|
|
|
+ total_gb = self._gpu_total_memory_gb()
|
|
|
|
|
+ if total_gb is None:
|
|
|
|
|
+ return int(default_batch_size)
|
|
|
|
|
+
|
|
|
|
|
+ target_fraction = float(cfg.get("target_memory_fraction", 0.75))
|
|
|
|
|
+ target_fraction = min(max(target_fraction, 0.1), 0.95)
|
|
|
|
|
+ reference_gpu_gb = float(cfg.get("reference_gpu_gb", 8.0))
|
|
|
|
|
+ reference_batch_size = int(cfg.get("reference_batch_size", default_batch_size))
|
|
|
|
|
+ max_batch_size = int(cfg.get("max_batch_size", reference_batch_size))
|
|
|
|
|
+ min_batch_size = int(cfg.get("min_batch_size", 1))
|
|
|
|
|
+
|
|
|
|
|
+ memory_penalty = float(cfg.get("memory_penalty", 1.0 if not ssl else 1.35))
|
|
|
|
|
+ scaled = int((reference_batch_size * total_gb * target_fraction) / max(reference_gpu_gb * 0.75 * memory_penalty, 1e-6))
|
|
|
|
|
+ batch_size = max(min_batch_size, min(max_batch_size, max(default_batch_size, scaled)))
|
|
|
|
|
+ return int(batch_size)
|
|
|
|
|
+
|
|
|
|
|
+ def _resolve_batch_size(self, key: str, default: int, *, ssl: bool = False) -> int:
|
|
|
|
|
+ train_cfg = self.cfg.get("train", {})
|
|
|
|
|
+ configured = int(train_cfg.get(key, default))
|
|
|
|
|
+ batch_size = self._estimate_auto_batch_size(default_batch_size=configured, ssl=ssl)
|
|
|
|
|
+ if self._auto_batch_size_enabled() and batch_size != configured:
|
|
|
|
|
+ print(
|
|
|
|
|
+ {
|
|
|
|
|
+ "message": "auto_batch_size adjusted",
|
|
|
|
|
+ "key": key,
|
|
|
|
|
+ "configured": configured,
|
|
|
|
|
+ "resolved": batch_size,
|
|
|
|
|
+ "gpu_total_gb": self._gpu_total_memory_gb(),
|
|
|
|
|
+ }
|
|
|
|
|
+ )
|
|
|
|
|
+ return batch_size
|
|
|
|
|
+
|
|
|
|
|
+ def _dataset_cfg(self) -> dict[str, Any]:
|
|
|
|
|
+ return self.cfg.get("dataset", {})
|
|
|
|
|
+
|
|
|
|
|
+ def _dataset_name(self) -> str:
|
|
|
|
|
+ dataset_cfg = self._dataset_cfg()
|
|
|
|
|
+ dataset_name = dataset_cfg.get("dataset_name") or dataset_cfg.get("name")
|
|
|
|
|
+ if not dataset_name:
|
|
|
|
|
+ raise ValueError("dataset.dataset_name is required.")
|
|
|
|
|
+ return str(dataset_name)
|
|
|
|
|
+
|
|
|
|
|
+ def _dataset_root(self) -> str:
|
|
|
|
|
+ dataset_cfg = self._dataset_cfg()
|
|
|
|
|
+ root = dataset_cfg.get("root")
|
|
|
|
|
+ if not root:
|
|
|
|
|
+ raise ValueError("dataset.root is required.")
|
|
|
|
|
+ return str(root)
|
|
|
|
|
+
|
|
|
|
|
+ def _image_size(self) -> tuple[int, int]:
|
|
|
|
|
+ dataset_cfg = self._dataset_cfg()
|
|
|
|
|
+ image_size = dataset_cfg.get("image_size")
|
|
|
|
|
+ if image_size is None:
|
|
|
|
|
+ raise ValueError("dataset.image_size is required.")
|
|
|
|
|
+ return int(image_size[0]), int(image_size[1])
|
|
|
|
|
+
|
|
|
|
|
+ def _build_resize_transform(self, *, mode: str) -> Any:
|
|
|
|
|
+ height, width = self._image_size()
|
|
|
|
|
+ interpolation_mode = "bilinear" if mode == "image" else "nearest"
|
|
|
|
|
+
|
|
|
|
|
+ def _transform(tensor: torch.Tensor) -> torch.Tensor:
|
|
|
|
|
+ resized = F.interpolate(
|
|
|
|
|
+ tensor.unsqueeze(0),
|
|
|
|
|
+ size=(height, width),
|
|
|
|
|
+ mode=interpolation_mode,
|
|
|
|
|
+ align_corners=False if interpolation_mode != "nearest" else None,
|
|
|
|
|
+ )
|
|
|
|
|
+ return resized.squeeze(0)
|
|
|
|
|
+
|
|
|
|
|
+ return _transform
|
|
|
|
|
+
|
|
|
|
|
+ def _build_segmentation_loader(
|
|
|
|
|
+ self,
|
|
|
|
|
+ *,
|
|
|
|
|
+ split: str,
|
|
|
|
|
+ batch_size: int,
|
|
|
|
|
+ shuffle: bool,
|
|
|
|
|
+ split_file: str | None = None,
|
|
|
|
|
+ ):
|
|
|
|
|
+ dataset_cfg = self._dataset_cfg()
|
|
|
|
|
+ train_cfg = self.cfg.get("train", {})
|
|
|
|
|
+ num_workers = max(0, int(train_cfg.get("num_workers", 0)))
|
|
|
|
|
+ persistent_workers = bool(train_cfg.get("persistent_workers", False)) if num_workers > 0 else False
|
|
|
|
|
+ loader = build_dataloader(
|
|
|
|
|
+ dataset_name=self._dataset_name(),
|
|
|
|
|
+ root=self._dataset_root(),
|
|
|
|
|
+ split=split,
|
|
|
|
|
+ split_file=split_file,
|
|
|
|
|
+ batch_size=batch_size,
|
|
|
|
|
+ shuffle=shuffle,
|
|
|
|
|
+ num_workers=num_workers,
|
|
|
|
|
+ image_transform=self._build_resize_transform(mode="image"),
|
|
|
|
|
+ mask_transform=self._build_resize_transform(mode="mask"),
|
|
|
|
|
+ pin_memory=bool(train_cfg.get("pin_memory", self.device.type == "cuda")),
|
|
|
|
|
+ persistent_workers=persistent_workers,
|
|
|
|
|
+ prefetch_factor=train_cfg.get("prefetch_factor") if num_workers > 0 else None,
|
|
|
|
|
+ )
|
|
|
|
|
+ return loader
|
|
|
|
|
+
|
|
|
|
|
+ def _build_val_loader(
|
|
|
|
|
+ self,
|
|
|
|
|
+ *,
|
|
|
|
|
+ batch_size: int,
|
|
|
|
|
+ shuffle: bool = False,
|
|
|
|
|
+ ):
|
|
|
|
|
+ dataset_cfg = self._dataset_cfg()
|
|
|
|
|
+ val_split = dataset_cfg.get("val_split", "val")
|
|
|
|
|
+ if val_split is None:
|
|
|
|
|
+ return None
|
|
|
|
|
+ return self._build_segmentation_loader(
|
|
|
|
|
+ split=str(val_split),
|
|
|
|
|
+ split_file=dataset_cfg.get("val_split_file"),
|
|
|
|
|
+ batch_size=batch_size,
|
|
|
|
|
+ shuffle=shuffle,
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ def _checkpoint_cfg(self) -> dict[str, Any]:
|
|
|
|
|
+ return self.cfg.get("checkpoint", {})
|
|
|
|
|
+
|
|
|
|
|
+ def _logging_cfg(self) -> dict[str, Any]:
|
|
|
|
|
+ return self.cfg.get("logging", {})
|
|
|
|
|
+
|
|
|
|
|
+ def _validation_cfg(self) -> dict[str, Any]:
|
|
|
|
|
+ return self.cfg.get("validation", {})
|
|
|
|
|
+
|
|
|
|
|
+ def _checkpoint_enabled(self) -> bool:
|
|
|
|
|
+ return bool(self._checkpoint_cfg().get("save", True))
|
|
|
|
|
+
|
|
|
|
|
+ def _best_mode(self) -> str:
|
|
|
|
|
+ return str(self._checkpoint_cfg().get("monitor_mode", "min"))
|
|
|
|
|
+
|
|
|
|
|
+ def _is_better_metric(self, metric: float) -> bool:
|
|
|
|
|
+ if self.best_metric is None:
|
|
|
|
|
+ return True
|
|
|
|
|
+ if self._best_mode() == "max":
|
|
|
|
|
+ return metric > self.best_metric
|
|
|
|
|
+ return metric < self.best_metric
|
|
|
|
|
+
|
|
|
|
|
+ def _save_checkpoint(self, filename: str, state: dict[str, Any]) -> Path | None:
|
|
|
|
|
+ if not self._checkpoint_enabled():
|
|
|
|
|
+ return None
|
|
|
|
|
+ path = self.output_dir / filename
|
|
|
|
|
+ torch.save(state, path)
|
|
|
|
|
+ return path
|
|
|
|
|
+
|
|
|
|
|
+ def _resume_checkpoint_path(self) -> Path | None:
|
|
|
|
|
+ resume_path = self._checkpoint_cfg().get("resume")
|
|
|
|
|
+ if not resume_path:
|
|
|
|
|
+ return None
|
|
|
|
|
+ path = Path(str(resume_path))
|
|
|
|
|
+ if not path.is_absolute():
|
|
|
|
|
+ path = Path.cwd() / path
|
|
|
|
|
+ return path
|
|
|
|
|
+
|
|
|
|
|
+ def _maybe_resume(
|
|
|
|
|
+ self,
|
|
|
|
|
+ *,
|
|
|
|
|
+ module_map: dict[str, Any],
|
|
|
|
|
+ optimizer: Any | None = None,
|
|
|
|
|
+ scheduler: Any | None = None,
|
|
|
|
|
+ ) -> dict[str, Any] | None:
|
|
|
|
|
+ path = self._resume_checkpoint_path()
|
|
|
|
|
+ if path is None:
|
|
|
|
|
+ return None
|
|
|
|
|
+ if not path.exists():
|
|
|
|
|
+ raise FileNotFoundError(f"Resume checkpoint not found: {path}")
|
|
|
|
|
+
|
|
|
|
|
+ checkpoint = torch.load(path, map_location="cpu")
|
|
|
|
|
+ strict = bool(self._checkpoint_cfg().get("resume_strict", True))
|
|
|
|
|
+ for key, module in module_map.items():
|
|
|
|
|
+ if module is None:
|
|
|
|
|
+ continue
|
|
|
|
|
+ state_dict = checkpoint.get(key)
|
|
|
|
|
+ if state_dict is not None:
|
|
|
|
|
+ module.load_state_dict(state_dict, strict=strict)
|
|
|
|
|
+
|
|
|
|
|
+ if optimizer is not None and checkpoint.get("optimizer") is not None:
|
|
|
|
|
+ optimizer.load_state_dict(checkpoint["optimizer"])
|
|
|
|
|
+ if scheduler is not None and checkpoint.get("scheduler") is not None:
|
|
|
|
|
+ scheduler.load_state_dict(checkpoint["scheduler"])
|
|
|
|
|
+ if checkpoint.get("grad_scaler") is not None:
|
|
|
|
|
+ self.grad_scaler.load_state_dict(checkpoint["grad_scaler"])
|
|
|
|
|
+
|
|
|
|
|
+ if checkpoint.get("best_metric") is not None:
|
|
|
|
|
+ self.best_metric = float(checkpoint["best_metric"])
|
|
|
|
|
+ elif checkpoint.get("metrics") is not None:
|
|
|
|
|
+ monitor_name = str(self._checkpoint_cfg().get("monitor", "total"))
|
|
|
|
|
+ monitor_value = checkpoint["metrics"].get(f"val_{monitor_name}")
|
|
|
|
|
+ if monitor_value is None:
|
|
|
|
|
+ monitor_value = checkpoint["metrics"].get(monitor_name)
|
|
|
|
|
+ if monitor_value is not None:
|
|
|
|
|
+ self.best_metric = float(monitor_value)
|
|
|
|
|
+ if checkpoint.get("no_improve_epochs") is not None:
|
|
|
|
|
+ self.no_improve_epochs = int(checkpoint["no_improve_epochs"])
|
|
|
|
|
+
|
|
|
|
|
+ if bool(self._checkpoint_cfg().get("resume_training", True)):
|
|
|
|
|
+ self.start_epoch = int(checkpoint.get("epoch", -1)) + 1
|
|
|
|
|
+ return checkpoint
|
|
|
|
|
+
|
|
|
|
|
+ def _validation_enabled(self) -> bool:
|
|
|
|
|
+ return bool(self._validation_cfg().get("enabled", True))
|
|
|
|
|
+
|
|
|
|
|
+ def _validation_interval(self) -> int:
|
|
|
|
|
+ return max(1, int(self._validation_cfg().get("interval", 1)))
|
|
|
|
|
+
|
|
|
|
|
+ def _should_validate(self, epoch: int) -> bool:
|
|
|
|
|
+ return self._validation_enabled() and ((epoch + 1) % self._validation_interval() == 0)
|
|
|
|
|
+
|
|
|
|
|
+ def _metric_task_mode(self) -> str:
|
|
|
|
|
+ validation_cfg = self._validation_cfg()
|
|
|
|
|
+ metrics_cfg = validation_cfg.get("metrics", self.cfg.get("metrics"))
|
|
|
|
|
+ if isinstance(metrics_cfg, dict):
|
|
|
|
|
+ return str(metrics_cfg.get("task_mode", "binary"))
|
|
|
|
|
+ return "binary"
|
|
|
|
|
+
|
|
|
|
|
+ def _metric_threshold(self) -> float:
|
|
|
|
|
+ validation_cfg = self._validation_cfg()
|
|
|
|
|
+ threshold = validation_cfg.get("threshold", 0.5)
|
|
|
|
|
+ return float(threshold)
|
|
|
|
|
+
|
|
|
|
|
+ def _build_validation_metrics(self) -> dict[str, Any]:
|
|
|
|
|
+ validation_cfg = self._validation_cfg()
|
|
|
|
|
+ metrics_cfg = validation_cfg.get("metrics", self.cfg.get("metrics"))
|
|
|
|
|
+ if metrics_cfg is None:
|
|
|
|
|
+ return {}
|
|
|
|
|
+ return build_metrics(metrics_cfg)
|
|
|
|
|
+
|
|
|
|
|
+ def _early_stopping_enabled(self) -> bool:
|
|
|
|
|
+ return bool(self._validation_cfg().get("early_stopping", False))
|
|
|
|
|
+
|
|
|
|
|
+ def _early_stopping_patience(self) -> int:
|
|
|
|
|
+ return max(1, int(self._validation_cfg().get("early_stopping_patience", 10)))
|
|
|
|
|
+
|
|
|
|
|
+ def _early_stopping_min_delta(self) -> float:
|
|
|
|
|
+ return float(self._validation_cfg().get("early_stopping_min_delta", 0.0))
|
|
|
|
|
+
|
|
|
|
|
+ def _update_validation_metrics(
|
|
|
|
|
+ self,
|
|
|
|
|
+ metrics: dict[str, Any],
|
|
|
|
|
+ *,
|
|
|
|
|
+ logits: torch.Tensor,
|
|
|
|
|
+ target: torch.Tensor,
|
|
|
|
|
+ ) -> None:
|
|
|
|
|
+ if not metrics:
|
|
|
|
|
+ return
|
|
|
|
|
+ update_metrics(
|
|
|
|
|
+ metrics,
|
|
|
|
|
+ logits,
|
|
|
|
|
+ target,
|
|
|
|
|
+ task_mode=self._metric_task_mode(),
|
|
|
|
|
+ threshold=self._metric_threshold(),
|
|
|
|
|
+ num_classes=int(self._dataset_cfg().get("num_classes", 1)),
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ def _compute_validation_metric_values(self, metrics: dict[str, Any]) -> dict[str, float]:
|
|
|
|
|
+ if not metrics:
|
|
|
|
|
+ return {}
|
|
|
|
|
+ values = compute_metrics(metrics)
|
|
|
|
|
+ reset_metrics(metrics)
|
|
|
|
|
+ return values
|
|
|
|
|
+
|
|
|
|
|
+ def _init_swanlab(self) -> None:
|
|
|
|
|
+ logging_cfg = self._logging_cfg()
|
|
|
|
|
+ if not bool(logging_cfg.get("use_swanlab", False)):
|
|
|
|
|
+ return
|
|
|
|
|
+ if swanlab is None:
|
|
|
|
|
+ print("SwanLab is not installed. Logging will continue without SwanLab.")
|
|
|
|
|
+ return
|
|
|
|
|
+
|
|
|
|
|
+ run_name = logging_cfg.get("experiment_name") or self.output_dir.name
|
|
|
|
|
+ self.swanlab_run = swanlab.init(
|
|
|
|
|
+ project=logging_cfg.get("project", "X_SSL_Net"),
|
|
|
|
|
+ name=run_name,
|
|
|
|
|
+ config=self.cfg,
|
|
|
|
|
+ mode=logging_cfg.get("swanlab_mode"),
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ def _log_metrics(self, metrics: dict[str, float], *, step: int) -> None:
|
|
|
|
|
+ if self.swanlab_run is None:
|
|
|
|
|
+ return
|
|
|
|
|
+ swanlab.log(metrics, step=step)
|
|
|
|
|
+
|
|
|
|
|
+ def _close_loggers(self) -> None:
|
|
|
|
|
+ if self.swanlab_run is not None:
|
|
|
|
|
+ swanlab.finish()
|
|
|
|
|
+ self.swanlab_run = None
|
|
|
|
|
+
|
|
|
|
|
+ def _log_interval(self) -> int:
|
|
|
|
|
+ return max(1, int(self._logging_cfg().get("log_interval", 20)))
|
|
|
|
|
+
|
|
|
|
|
+ def _grad_clip_cfg(self) -> dict[str, Any]:
|
|
|
|
|
+ cfg = self.cfg.get("train", {}).get("grad_clip", {})
|
|
|
|
|
+ return cfg if isinstance(cfg, dict) else {}
|
|
|
|
|
+
|
|
|
|
|
+ def _grad_clip_enabled(self) -> bool:
|
|
|
|
|
+ return bool(self._grad_clip_cfg().get("enabled", False))
|
|
|
|
|
+
|
|
|
|
|
+ def _clip_gradients(self, module: nn.Module | None) -> float | None:
|
|
|
|
|
+ if module is None or not self._grad_clip_enabled():
|
|
|
|
|
+ return None
|
|
|
|
|
+ cfg = self._grad_clip_cfg()
|
|
|
|
|
+ max_norm = float(cfg.get("max_norm", 1.0))
|
|
|
|
|
+ norm_type = float(cfg.get("norm_type", 2.0))
|
|
|
|
|
+ params = [param for param in module.parameters() if param.requires_grad and param.grad is not None]
|
|
|
|
|
+ if not params:
|
|
|
|
|
+ return None
|
|
|
|
|
+ total_norm = torch.nn.utils.clip_grad_norm_(params, max_norm=max_norm, norm_type=norm_type)
|
|
|
|
|
+ return float(total_norm.detach().cpu() if isinstance(total_norm, torch.Tensor) else total_norm)
|
|
|
|
|
+
|
|
|
|
|
+ def _current_lrs(self, optimizer: Any | None) -> list[float]:
|
|
|
|
|
+ if optimizer is None:
|
|
|
|
|
+ return []
|
|
|
|
|
+ return [float(group.get("lr", 0.0)) for group in optimizer.param_groups]
|
|
|
|
|
+
|
|
|
|
|
+ @staticmethod
|
|
|
|
|
+ def _count_parameters(module: nn.Module | None) -> dict[str, int]:
|
|
|
|
|
+ if module is None:
|
|
|
|
|
+ return {"total": 0, "trainable": 0}
|
|
|
|
|
+ total = sum(param.numel() for param in module.parameters())
|
|
|
|
|
+ trainable = sum(param.numel() for param in module.parameters() if param.requires_grad)
|
|
|
|
|
+ return {"total": int(total), "trainable": int(trainable)}
|
|
|
|
|
+
|
|
|
|
|
+ @staticmethod
|
|
|
|
|
+ def _loader_summary(loader: Any | None) -> dict[str, Any] | None:
|
|
|
|
|
+ if loader is None:
|
|
|
|
|
+ return None
|
|
|
|
|
+ dataset = getattr(loader, "dataset", None)
|
|
|
|
|
+ return {
|
|
|
|
|
+ "dataset_size": len(dataset) if dataset is not None else None,
|
|
|
|
|
+ "num_batches": len(loader),
|
|
|
|
|
+ "batch_size": getattr(loader, "batch_size", None),
|
|
|
|
|
+ "num_workers": getattr(loader, "num_workers", None),
|
|
|
|
|
+ "pin_memory": getattr(loader, "pin_memory", None),
|
|
|
|
|
+ "persistent_workers": getattr(loader, "persistent_workers", None),
|
|
|
|
|
+ "prefetch_factor": getattr(loader, "prefetch_factor", None),
|
|
|
|
|
+ "drop_last": getattr(loader, "drop_last", None),
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ def _training_setup_summary(
|
|
|
|
|
+ self,
|
|
|
|
|
+ *,
|
|
|
|
|
+ model_map: dict[str, nn.Module | None],
|
|
|
|
|
+ loader_map: dict[str, Any | None],
|
|
|
|
|
+ optimizer: Any | None = None,
|
|
|
|
|
+ scheduler: Any | None = None,
|
|
|
|
|
+ ) -> dict[str, Any]:
|
|
|
|
|
+ return {
|
|
|
|
|
+ "trainer": self.cfg.get("trainer", {}).get("name"),
|
|
|
|
|
+ "device": str(self.device),
|
|
|
|
|
+ "amp_enabled": self._amp_enabled(),
|
|
|
|
|
+ "output_dir": str(self.output_dir),
|
|
|
|
|
+ "start_epoch": self.start_epoch,
|
|
|
|
|
+ "train": self.cfg.get("train", {}),
|
|
|
|
|
+ "dataset": self.cfg.get("dataset", {}),
|
|
|
|
|
+ "model": self.cfg.get("model", {}),
|
|
|
|
|
+ "optimizer": self.cfg.get("optimizer", {}),
|
|
|
|
|
+ "scheduler": self.cfg.get("scheduler"),
|
|
|
|
|
+ "current_lrs": self._current_lrs(optimizer),
|
|
|
|
|
+ "validation": self.cfg.get("validation", {}),
|
|
|
|
|
+ "checkpoint": self.cfg.get("checkpoint", {}),
|
|
|
|
|
+ "logging": self.cfg.get("logging", {}),
|
|
|
|
|
+ "model_parameters": {
|
|
|
|
|
+ name: self._count_parameters(module)
|
|
|
|
|
+ for name, module in model_map.items()
|
|
|
|
|
+ },
|
|
|
|
|
+ "loaders": {
|
|
|
|
|
+ name: self._loader_summary(loader)
|
|
|
|
|
+ for name, loader in loader_map.items()
|
|
|
|
|
+ },
|
|
|
|
|
+ "cuda": {
|
|
|
|
|
+ "available": torch.cuda.is_available(),
|
|
|
|
|
+ "device_name": torch.cuda.get_device_name(self.device) if self.device.type == "cuda" else None,
|
|
|
|
|
+ "device_count": torch.cuda.device_count(),
|
|
|
|
|
+ },
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ def _print_training_setup(
|
|
|
|
|
+ self,
|
|
|
|
|
+ *,
|
|
|
|
|
+ model_map: dict[str, nn.Module | None],
|
|
|
|
|
+ loader_map: dict[str, Any | None],
|
|
|
|
|
+ optimizer: Any | None = None,
|
|
|
|
|
+ scheduler: Any | None = None,
|
|
|
|
|
+ ) -> None:
|
|
|
|
|
+ if not bool(self._logging_cfg().get("print_training_setup", True)):
|
|
|
|
|
+ return
|
|
|
|
|
+ summary = self._training_setup_summary(
|
|
|
|
|
+ model_map=model_map,
|
|
|
|
|
+ loader_map=loader_map,
|
|
|
|
|
+ optimizer=optimizer,
|
|
|
|
|
+ scheduler=scheduler,
|
|
|
|
|
+ )
|
|
|
|
|
+ print("========== TRAINING SETUP ==========")
|
|
|
|
|
+ pprint.pprint(summary, sort_dicts=False, width=120)
|
|
|
|
|
+ print("======== END TRAINING SETUP ========")
|
|
|
|
|
+
|
|
|
|
|
+ def _gpu_memory_mb(self) -> float:
|
|
|
|
|
+ if self.device.type != "cuda" or not torch.cuda.is_available():
|
|
|
|
|
+ return 0.0
|
|
|
|
|
+ return float(torch.cuda.max_memory_allocated(device=self.device) / (1024 ** 2))
|
|
|
|
|
+
|
|
|
|
|
+ def _performance_snapshot(
|
|
|
|
|
+ self,
|
|
|
|
|
+ *,
|
|
|
|
|
+ epoch: int,
|
|
|
|
|
+ step: int,
|
|
|
|
|
+ num_steps: int,
|
|
|
|
|
+ data_time: float,
|
|
|
|
|
+ iter_time: float,
|
|
|
|
|
+ metrics: dict[str, float],
|
|
|
|
|
+ prefix: str = "train",
|
|
|
|
|
+ ) -> dict[str, float | int]:
|
|
|
|
|
+ snapshot: dict[str, float | int] = {
|
|
|
|
|
+ "epoch": epoch,
|
|
|
|
|
+ "step": step,
|
|
|
|
|
+ "num_steps": num_steps,
|
|
|
|
|
+ "data_time": data_time,
|
|
|
|
|
+ "iter_time": iter_time,
|
|
|
|
|
+ "gpu_memory_mb": self._gpu_memory_mb(),
|
|
|
|
|
+ }
|
|
|
|
|
+ lrs = self._current_lrs(getattr(self, "optimizer", None))
|
|
|
|
|
+ if lrs:
|
|
|
|
|
+ snapshot["lr"] = lrs[0]
|
|
|
|
|
+ for key, value in metrics.items():
|
|
|
|
|
+ snapshot[f"{prefix}_{key}"] = value
|
|
|
|
|
+ return snapshot
|
|
|
|
|
+
|
|
|
|
|
+ def _maybe_log_step(
|
|
|
|
|
+ self,
|
|
|
|
|
+ *,
|
|
|
|
|
+ epoch: int,
|
|
|
|
|
+ step: int,
|
|
|
|
|
+ num_steps: int,
|
|
|
|
|
+ data_time: float,
|
|
|
|
|
+ iter_time: float,
|
|
|
|
|
+ metrics: dict[str, float],
|
|
|
|
|
+ prefix: str = "train",
|
|
|
|
|
+ ) -> None:
|
|
|
|
|
+ if step % self._log_interval() != 0 and step != num_steps:
|
|
|
|
|
+ return
|
|
|
|
|
+ snapshot = self._performance_snapshot(
|
|
|
|
|
+ epoch=epoch,
|
|
|
|
|
+ step=step,
|
|
|
|
|
+ num_steps=num_steps,
|
|
|
|
|
+ data_time=data_time,
|
|
|
|
|
+ iter_time=iter_time,
|
|
|
|
|
+ metrics=metrics,
|
|
|
|
|
+ prefix=prefix,
|
|
|
|
|
+ )
|
|
|
|
|
+ print(snapshot)
|
|
|
|
|
+ log_metrics = {
|
|
|
|
|
+ f"{prefix}/{key}": value
|
|
|
|
|
+ for key, value in metrics.items()
|
|
|
|
|
+ }
|
|
|
|
|
+ log_metrics.update(
|
|
|
|
|
+ {
|
|
|
|
|
+ f"{prefix}/data_time": data_time,
|
|
|
|
|
+ f"{prefix}/iter_time": iter_time,
|
|
|
|
|
+ f"{prefix}/gpu_memory_mb": float(snapshot["gpu_memory_mb"]),
|
|
|
|
|
+ }
|
|
|
|
|
+ )
|
|
|
|
|
+ if "lr" in snapshot:
|
|
|
|
|
+ log_metrics[f"{prefix}/lr"] = float(snapshot["lr"])
|
|
|
|
|
+ self._log_metrics(log_metrics, step=epoch * max(1, num_steps) + step)
|
|
|
|
|
+
|
|
|
|
|
+ @staticmethod
|
|
|
|
|
+ def _average_metric_sums(metric_sums: dict[str, float], steps: int) -> dict[str, float]:
|
|
|
|
|
+ if steps <= 0:
|
|
|
|
|
+ return {}
|
|
|
|
|
+ return {key: value / steps for key, value in metric_sums.items()}
|
|
|
|
|
+
|
|
|
|
|
+ def _base_checkpoint_state(self, *, epoch: int, metrics: dict[str, float] | None = None) -> dict[str, Any]:
|
|
|
|
|
+ state = {
|
|
|
|
|
+ "epoch": epoch,
|
|
|
|
|
+ "cfg": self.cfg,
|
|
|
|
|
+ "metrics": metrics or {},
|
|
|
|
|
+ "grad_scaler": self.grad_scaler.state_dict(),
|
|
|
|
|
+ "no_improve_epochs": self.no_improve_epochs,
|
|
|
|
|
+ }
|
|
|
|
|
+ return state
|
|
|
|
|
+
|
|
|
|
|
+ def _finalize_epoch(
|
|
|
|
|
+ self,
|
|
|
|
|
+ *,
|
|
|
|
|
+ epoch: int,
|
|
|
|
|
+ train_metrics: dict[str, float],
|
|
|
|
|
+ val_metrics: dict[str, float] | None,
|
|
|
|
|
+ checkpoint_state: dict[str, Any],
|
|
|
|
|
+ ) -> tuple[dict[str, Any], bool]:
|
|
|
|
|
+ merged_metrics = dict(train_metrics)
|
|
|
|
|
+ if val_metrics is not None:
|
|
|
|
|
+ merged_metrics.update({f"val_{key}": value for key, value in val_metrics.items()})
|
|
|
|
|
+
|
|
|
|
|
+ improved = False
|
|
|
|
|
+ if val_metrics is not None:
|
|
|
|
|
+ monitor_name = str(self._checkpoint_cfg().get("monitor", "total"))
|
|
|
|
|
+ if monitor_name not in val_metrics:
|
|
|
|
|
+ raise KeyError(f"Checkpoint monitor '{monitor_name}' not found in val metrics.")
|
|
|
|
|
+ monitor_value = float(val_metrics[monitor_name])
|
|
|
|
|
+ delta = self._early_stopping_min_delta()
|
|
|
|
|
+ previous_best = self.best_metric
|
|
|
|
|
+ is_better = self._is_better_metric(monitor_value)
|
|
|
|
|
+ if previous_best is not None and self._best_mode() == "max":
|
|
|
|
|
+ is_better = monitor_value > (previous_best + delta)
|
|
|
|
|
+ elif previous_best is not None and self._best_mode() == "min":
|
|
|
|
|
+ is_better = monitor_value < (previous_best - delta)
|
|
|
|
|
+
|
|
|
|
|
+ if is_better:
|
|
|
|
|
+ self.best_metric = monitor_value
|
|
|
|
|
+ self.no_improve_epochs = 0
|
|
|
|
|
+ improved = True
|
|
|
|
|
+ best_state = dict(checkpoint_state)
|
|
|
|
|
+ best_state.update(
|
|
|
|
|
+ self._base_checkpoint_state(
|
|
|
|
|
+ epoch=epoch,
|
|
|
|
|
+ metrics=merged_metrics,
|
|
|
|
|
+ )
|
|
|
|
|
+ )
|
|
|
|
|
+ best_state["best_metric"] = self.best_metric
|
|
|
|
|
+ self._save_checkpoint("best.pth", best_state)
|
|
|
|
|
+ else:
|
|
|
|
|
+ self.no_improve_epochs += 1
|
|
|
|
|
+
|
|
|
|
|
+ save_last = bool(self._checkpoint_cfg().get("save_last", True))
|
|
|
|
|
+ if save_last:
|
|
|
|
|
+ last_state = dict(checkpoint_state)
|
|
|
|
|
+ last_state.update(self._base_checkpoint_state(epoch=epoch, metrics=merged_metrics))
|
|
|
|
|
+ if self.best_metric is not None:
|
|
|
|
|
+ last_state["best_metric"] = self.best_metric
|
|
|
|
|
+ self._save_checkpoint("last.pth", last_state)
|
|
|
|
|
+
|
|
|
|
|
+ summary = {"epoch": epoch}
|
|
|
|
|
+ summary.update(train_metrics)
|
|
|
|
|
+ if val_metrics is not None:
|
|
|
|
|
+ summary.update({f"val_{key}": value for key, value in val_metrics.items()})
|
|
|
|
|
+ if self.best_metric is not None:
|
|
|
|
|
+ summary["best_metric"] = float(self.best_metric)
|
|
|
|
|
+ summary["no_improve_epochs"] = self.no_improve_epochs
|
|
|
|
|
+ lrs = self._current_lrs(getattr(self, "optimizer", None))
|
|
|
|
|
+ if lrs:
|
|
|
|
|
+ summary["lr"] = lrs[0]
|
|
|
|
|
+ self._log_metrics(summary, step=epoch)
|
|
|
|
|
+ should_stop = False
|
|
|
|
|
+ if val_metrics is not None and self._early_stopping_enabled():
|
|
|
|
|
+ should_stop = self.no_improve_epochs >= self._early_stopping_patience()
|
|
|
|
|
+ summary["early_stop"] = should_stop
|
|
|
|
|
+ summary["improved"] = improved
|
|
|
|
|
+ return summary, should_stop
|
|
|
|
|
+
|
|
|
|
|
+ @abstractmethod
|
|
|
|
|
+ def build(self) -> None:
|
|
|
|
|
+ """
|
|
|
|
|
+ 创建模型、优化器、数据加载器等运行所需对象。
|
|
|
|
|
+ """
|
|
|
|
|
+
|
|
|
|
|
+ @abstractmethod
|
|
|
|
|
+ def train(self) -> None:
|
|
|
|
|
+ """
|
|
|
|
|
+ 执行完整训练流程。
|
|
|
|
|
+ """
|