# base.py
from __future__ import annotations

from abc import ABC, abstractmethod
import math
from pathlib import Path
import pprint
import random
import time
from typing import Any

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.amp import GradScaler

from lib.data import build_dataloader
from lib.tools import build_metrics, compute_metrics, reset_metrics, update_metrics

try:
    import swanlab
except ImportError:
    swanlab = None
  19. class BaseTrainer(ABC):
  20. """
  21. 训练器基类。
  22. 设计目标:
  23. - 统一配置入口
  24. - 统一模型/优化器/调度器创建
  25. - 不同训练流程只重写最少的方法
  26. """
  27. def __init__(self, cfg: dict[str, Any], args: Any | None = None) -> None:
  28. self.cfg = cfg
  29. self.args = args
  30. self._set_random_seed()
  31. self.device = self._build_device()
  32. self.output_dir = self._build_output_dir()
  33. self.start_epoch = 0
  34. self.best_metric: float | None = None
  35. self.no_improve_epochs = 0
  36. self.swanlab_run = None
  37. self.grad_scaler = GradScaler("cuda", enabled=self._amp_enabled())
  38. def _set_random_seed(self) -> None:
  39. seed = int(self.cfg.get("train", {}).get("seed", 42))
  40. random.seed(seed)
  41. np.random.seed(seed)
  42. torch.manual_seed(seed)
  43. if torch.cuda.is_available():
  44. torch.cuda.manual_seed(seed)
  45. torch.cuda.manual_seed_all(seed)
  46. deterministic = bool(self.cfg.get("train", {}).get("deterministic", False))
  47. if deterministic:
  48. torch.backends.cudnn.deterministic = True
  49. torch.backends.cudnn.benchmark = False
  50. else:
  51. torch.backends.cudnn.deterministic = False
  52. torch.backends.cudnn.benchmark = True
  53. def _build_device(self) -> torch.device:
  54. device_name = self.cfg.get("train", {}).get("device", "cpu")
  55. if device_name == "cuda" and not torch.cuda.is_available():
  56. device_name = "cpu"
  57. return torch.device(device_name)
  58. def _build_output_dir(self) -> Path:
  59. output_dir = self.cfg.get("checkpoint", {}).get("dir", "outputs/supervised_segmentation")
  60. path = Path(output_dir)
  61. path.mkdir(parents=True, exist_ok=True)
  62. return path
  63. def _amp_enabled(self) -> bool:
  64. return bool(self.cfg.get("train", {}).get("amp", False)) and self.device.type == "cuda"
  65. def _auto_batch_size_cfg(self) -> dict[str, Any]:
  66. cfg = self.cfg.get("train", {}).get("auto_batch_size", {})
  67. return cfg if isinstance(cfg, dict) else {}
  68. def _auto_batch_size_enabled(self) -> bool:
  69. return bool(self._auto_batch_size_cfg().get("enabled", False))
  70. def _gpu_total_memory_gb(self) -> float | None:
  71. if self.device.type != "cuda" or not torch.cuda.is_available():
  72. return None
  73. props = torch.cuda.get_device_properties(self.device)
  74. return float(props.total_memory / (1024 ** 3))
  75. def _estimate_auto_batch_size(self, *, default_batch_size: int, ssl: bool = False) -> int:
  76. cfg = self._auto_batch_size_cfg()
  77. if not cfg.get("enabled", False):
  78. return int(default_batch_size)
  79. total_gb = self._gpu_total_memory_gb()
  80. if total_gb is None:
  81. return int(default_batch_size)
  82. target_fraction = float(cfg.get("target_memory_fraction", 0.75))
  83. target_fraction = min(max(target_fraction, 0.1), 0.95)
  84. reference_gpu_gb = float(cfg.get("reference_gpu_gb", 8.0))
  85. reference_batch_size = int(cfg.get("reference_batch_size", default_batch_size))
  86. max_batch_size = int(cfg.get("max_batch_size", reference_batch_size))
  87. min_batch_size = int(cfg.get("min_batch_size", 1))
  88. memory_penalty = float(cfg.get("memory_penalty", 1.0 if not ssl else 1.35))
  89. scaled = int((reference_batch_size * total_gb * target_fraction) / max(reference_gpu_gb * 0.75 * memory_penalty, 1e-6))
  90. batch_size = max(min_batch_size, min(max_batch_size, max(default_batch_size, scaled)))
  91. return int(batch_size)
  92. def _resolve_batch_size(self, key: str, default: int, *, ssl: bool = False) -> int:
  93. train_cfg = self.cfg.get("train", {})
  94. configured = int(train_cfg.get(key, default))
  95. batch_size = self._estimate_auto_batch_size(default_batch_size=configured, ssl=ssl)
  96. if self._auto_batch_size_enabled() and batch_size != configured:
  97. print(
  98. {
  99. "message": "auto_batch_size adjusted",
  100. "key": key,
  101. "configured": configured,
  102. "resolved": batch_size,
  103. "gpu_total_gb": self._gpu_total_memory_gb(),
  104. }
  105. )
  106. return batch_size
  107. def _dataset_cfg(self) -> dict[str, Any]:
  108. return self.cfg.get("dataset", {})
  109. def _dataset_name(self) -> str:
  110. dataset_cfg = self._dataset_cfg()
  111. dataset_name = dataset_cfg.get("dataset_name") or dataset_cfg.get("name")
  112. if not dataset_name:
  113. raise ValueError("dataset.dataset_name is required.")
  114. return str(dataset_name)
  115. def _dataset_root(self) -> str:
  116. dataset_cfg = self._dataset_cfg()
  117. root = dataset_cfg.get("root")
  118. if not root:
  119. raise ValueError("dataset.root is required.")
  120. return str(root)
  121. def _image_size(self) -> tuple[int, int]:
  122. dataset_cfg = self._dataset_cfg()
  123. image_size = dataset_cfg.get("image_size")
  124. if image_size is None:
  125. raise ValueError("dataset.image_size is required.")
  126. return int(image_size[0]), int(image_size[1])
  127. def _build_resize_transform(self, *, mode: str) -> Any:
  128. height, width = self._image_size()
  129. interpolation_mode = "bilinear" if mode == "image" else "nearest"
  130. def _transform(tensor: torch.Tensor) -> torch.Tensor:
  131. resized = F.interpolate(
  132. tensor.unsqueeze(0),
  133. size=(height, width),
  134. mode=interpolation_mode,
  135. align_corners=False if interpolation_mode != "nearest" else None,
  136. )
  137. return resized.squeeze(0)
  138. return _transform
  139. def _build_segmentation_loader(
  140. self,
  141. *,
  142. split: str,
  143. batch_size: int,
  144. shuffle: bool,
  145. split_file: str | None = None,
  146. augmentation_config: dict[str, Any] | None = None,
  147. ):
  148. dataset_cfg = self._dataset_cfg()
  149. train_cfg = self.cfg.get("train", {})
  150. num_workers = max(0, int(train_cfg.get("num_workers", 0)))
  151. persistent_workers = bool(train_cfg.get("persistent_workers", False)) if num_workers > 0 else False
  152. loader = build_dataloader(
  153. dataset_name=self._dataset_name(),
  154. root=self._dataset_root(),
  155. split=split,
  156. split_file=split_file,
  157. batch_size=batch_size,
  158. shuffle=shuffle,
  159. num_workers=num_workers,
  160. augmentation_config=augmentation_config,
  161. image_transform=self._build_resize_transform(mode="image"),
  162. mask_transform=self._build_resize_transform(mode="mask"),
  163. pin_memory=bool(train_cfg.get("pin_memory", self.device.type == "cuda")),
  164. persistent_workers=persistent_workers,
  165. prefetch_factor=train_cfg.get("prefetch_factor") if num_workers > 0 else None,
  166. )
  167. return loader
  168. def _build_val_loader(
  169. self,
  170. *,
  171. batch_size: int,
  172. shuffle: bool = False,
  173. ):
  174. dataset_cfg = self._dataset_cfg()
  175. val_split = dataset_cfg.get("val_split", "val")
  176. if val_split is None:
  177. return None
  178. return self._build_segmentation_loader(
  179. split=str(val_split),
  180. split_file=dataset_cfg.get("val_split_file"),
  181. batch_size=batch_size,
  182. shuffle=shuffle,
  183. augmentation_config=self.cfg.get("augmentation", {}).get("val"),
  184. )
  185. def _checkpoint_cfg(self) -> dict[str, Any]:
  186. return self.cfg.get("checkpoint", {})
  187. def _logging_cfg(self) -> dict[str, Any]:
  188. return self.cfg.get("logging", {})
  189. def _validation_cfg(self) -> dict[str, Any]:
  190. return self.cfg.get("validation", {})
  191. def _checkpoint_enabled(self) -> bool:
  192. return bool(self._checkpoint_cfg().get("save", True))
  193. def _best_mode(self) -> str:
  194. return str(self._checkpoint_cfg().get("monitor_mode", "min"))
  195. def _is_better_metric(self, metric: float) -> bool:
  196. if self.best_metric is None:
  197. return True
  198. if self._best_mode() == "max":
  199. return metric > self.best_metric
  200. return metric < self.best_metric
  201. def _save_checkpoint(self, filename: str, state: dict[str, Any]) -> Path | None:
  202. if not self._checkpoint_enabled():
  203. return None
  204. path = self.output_dir / filename
  205. torch.save(state, path)
  206. return path
  207. def _resume_checkpoint_path(self) -> Path | None:
  208. resume_path = self._checkpoint_cfg().get("resume")
  209. if not resume_path:
  210. return None
  211. path = Path(str(resume_path))
  212. if not path.is_absolute():
  213. path = Path.cwd() / path
  214. return path
  215. def _maybe_resume(
  216. self,
  217. *,
  218. module_map: dict[str, Any],
  219. optimizer: Any | None = None,
  220. scheduler: Any | None = None,
  221. ) -> dict[str, Any] | None:
  222. path = self._resume_checkpoint_path()
  223. if path is None:
  224. return None
  225. if not path.exists():
  226. raise FileNotFoundError(f"Resume checkpoint not found: {path}")
  227. checkpoint = torch.load(path, map_location="cpu")
  228. strict = bool(self._checkpoint_cfg().get("resume_strict", True))
  229. for key, module in module_map.items():
  230. if module is None:
  231. continue
  232. state_dict = checkpoint.get(key)
  233. if state_dict is not None:
  234. module.load_state_dict(state_dict, strict=strict)
  235. if optimizer is not None and checkpoint.get("optimizer") is not None:
  236. optimizer.load_state_dict(checkpoint["optimizer"])
  237. if scheduler is not None and checkpoint.get("scheduler") is not None:
  238. scheduler.load_state_dict(checkpoint["scheduler"])
  239. if checkpoint.get("grad_scaler") is not None:
  240. self.grad_scaler.load_state_dict(checkpoint["grad_scaler"])
  241. if checkpoint.get("best_metric") is not None:
  242. self.best_metric = float(checkpoint["best_metric"])
  243. elif checkpoint.get("metrics") is not None:
  244. monitor_name = str(self._checkpoint_cfg().get("monitor", "total"))
  245. monitor_value = checkpoint["metrics"].get(f"val_{monitor_name}")
  246. if monitor_value is None:
  247. monitor_value = checkpoint["metrics"].get(monitor_name)
  248. if monitor_value is not None:
  249. self.best_metric = float(monitor_value)
  250. if checkpoint.get("no_improve_epochs") is not None:
  251. self.no_improve_epochs = int(checkpoint["no_improve_epochs"])
  252. if bool(self._checkpoint_cfg().get("resume_training", True)):
  253. self.start_epoch = int(checkpoint.get("epoch", -1)) + 1
  254. return checkpoint
  255. def _validation_enabled(self) -> bool:
  256. return bool(self._validation_cfg().get("enabled", True))
  257. def _validation_interval(self) -> int:
  258. return max(1, int(self._validation_cfg().get("interval", 1)))
  259. def _should_validate(self, epoch: int) -> bool:
  260. return self._validation_enabled() and ((epoch + 1) % self._validation_interval() == 0)
  261. def _metric_task_mode(self) -> str:
  262. validation_cfg = self._validation_cfg()
  263. metrics_cfg = validation_cfg.get("metrics", self.cfg.get("metrics"))
  264. if isinstance(metrics_cfg, dict):
  265. return str(metrics_cfg.get("task_mode", "binary"))
  266. return "binary"
  267. def _metric_threshold(self) -> float:
  268. validation_cfg = self._validation_cfg()
  269. threshold = validation_cfg.get("threshold", 0.5)
  270. return float(threshold)
  271. def _build_validation_metrics(self) -> dict[str, Any]:
  272. validation_cfg = self._validation_cfg()
  273. metrics_cfg = validation_cfg.get("metrics", self.cfg.get("metrics"))
  274. if metrics_cfg is None:
  275. return {}
  276. return build_metrics(metrics_cfg)
  277. def _early_stopping_enabled(self) -> bool:
  278. return bool(self._validation_cfg().get("early_stopping", False))
  279. def _early_stopping_patience(self) -> int:
  280. return max(1, int(self._validation_cfg().get("early_stopping_patience", 10)))
  281. def _early_stopping_min_delta(self) -> float:
  282. return float(self._validation_cfg().get("early_stopping_min_delta", 0.0))
  283. def _update_validation_metrics(
  284. self,
  285. metrics: dict[str, Any],
  286. *,
  287. logits: torch.Tensor,
  288. target: torch.Tensor,
  289. ) -> None:
  290. if not metrics:
  291. return
  292. update_metrics(
  293. metrics,
  294. logits,
  295. target,
  296. task_mode=self._metric_task_mode(),
  297. threshold=self._metric_threshold(),
  298. num_classes=int(self._dataset_cfg().get("num_classes", 1)),
  299. )
  300. def _compute_validation_metric_values(self, metrics: dict[str, Any]) -> dict[str, float]:
  301. if not metrics:
  302. return {}
  303. values = compute_metrics(metrics)
  304. reset_metrics(metrics)
  305. return values
  306. def _init_swanlab(self) -> None:
  307. logging_cfg = self._logging_cfg()
  308. if not bool(logging_cfg.get("use_swanlab", False)):
  309. return
  310. if swanlab is None:
  311. print("SwanLab is not installed. Logging will continue without SwanLab.")
  312. return
  313. run_name = logging_cfg.get("experiment_name") or self.output_dir.name
  314. self.swanlab_run = swanlab.init(
  315. project=logging_cfg.get("project", "X_SSL_Net"),
  316. name=run_name,
  317. config=self.cfg,
  318. logdir=logging_cfg.get("swanlab_logdir", "swanlog"),
  319. mode=logging_cfg.get("swanlab_mode"),
  320. )
  321. def _log_metrics(self, metrics: dict[str, float], *, step: int) -> None:
  322. if self.swanlab_run is None:
  323. return
  324. swanlab.log(metrics, step=step)
  325. def _close_loggers(self) -> None:
  326. if self.swanlab_run is not None:
  327. swanlab.finish()
  328. self.swanlab_run = None
  329. def _log_interval(self) -> int:
  330. return max(1, int(self._logging_cfg().get("log_interval", 20)))
  331. def _grad_clip_cfg(self) -> dict[str, Any]:
  332. cfg = self.cfg.get("train", {}).get("grad_clip", {})
  333. return cfg if isinstance(cfg, dict) else {}
  334. def _grad_clip_enabled(self) -> bool:
  335. return bool(self._grad_clip_cfg().get("enabled", False))
  336. def _accum_steps(self) -> int:
  337. return max(1, int(self.cfg.get("train", {}).get("accum_steps", 1)))
  338. def _clip_gradients(self, module: nn.Module | None) -> float | None:
  339. if module is None or not self._grad_clip_enabled():
  340. return None
  341. cfg = self._grad_clip_cfg()
  342. max_norm = float(cfg.get("max_norm", 1.0))
  343. norm_type = float(cfg.get("norm_type", 2.0))
  344. params = [param for param in module.parameters() if param.requires_grad and param.grad is not None]
  345. if not params:
  346. return None
  347. total_norm = torch.nn.utils.clip_grad_norm_(params, max_norm=max_norm, norm_type=norm_type)
  348. return float(total_norm.detach().cpu() if isinstance(total_norm, torch.Tensor) else total_norm)
  349. def _current_lrs(self, optimizer: Any | None) -> list[float]:
  350. if optimizer is None:
  351. return []
  352. return [float(group.get("lr", 0.0)) for group in optimizer.param_groups]
  353. @staticmethod
  354. def _count_parameters(module: nn.Module | None) -> dict[str, int]:
  355. if module is None:
  356. return {"total": 0, "trainable": 0}
  357. total = sum(param.numel() for param in module.parameters())
  358. trainable = sum(param.numel() for param in module.parameters() if param.requires_grad)
  359. return {"total": int(total), "trainable": int(trainable)}
  360. @staticmethod
  361. def _loader_summary(loader: Any | None) -> dict[str, Any] | None:
  362. if loader is None:
  363. return None
  364. dataset = getattr(loader, "dataset", None)
  365. return {
  366. "dataset_size": len(dataset) if dataset is not None else None,
  367. "num_batches": len(loader),
  368. "batch_size": getattr(loader, "batch_size", None),
  369. "num_workers": getattr(loader, "num_workers", None),
  370. "pin_memory": getattr(loader, "pin_memory", None),
  371. "persistent_workers": getattr(loader, "persistent_workers", None),
  372. "prefetch_factor": getattr(loader, "prefetch_factor", None),
  373. "drop_last": getattr(loader, "drop_last", None),
  374. }
  375. def _training_setup_summary(
  376. self,
  377. *,
  378. model_map: dict[str, nn.Module | None],
  379. loader_map: dict[str, Any | None],
  380. optimizer: Any | None = None,
  381. scheduler: Any | None = None,
  382. ) -> dict[str, Any]:
  383. return {
  384. "trainer": self.cfg.get("trainer", {}).get("name"),
  385. "device": str(self.device),
  386. "amp_enabled": self._amp_enabled(),
  387. "output_dir": str(self.output_dir),
  388. "start_epoch": self.start_epoch,
  389. "train": self.cfg.get("train", {}),
  390. "dataset": self.cfg.get("dataset", {}),
  391. "model": self.cfg.get("model", {}),
  392. "optimizer": self.cfg.get("optimizer", {}),
  393. "scheduler": self.cfg.get("scheduler"),
  394. "current_lrs": self._current_lrs(optimizer),
  395. "validation": self.cfg.get("validation", {}),
  396. "checkpoint": self.cfg.get("checkpoint", {}),
  397. "logging": self.cfg.get("logging", {}),
  398. "model_parameters": {
  399. name: self._count_parameters(module)
  400. for name, module in model_map.items()
  401. },
  402. "loaders": {
  403. name: self._loader_summary(loader)
  404. for name, loader in loader_map.items()
  405. },
  406. "cuda": {
  407. "available": torch.cuda.is_available(),
  408. "device_name": torch.cuda.get_device_name(self.device) if self.device.type == "cuda" else None,
  409. "device_count": torch.cuda.device_count(),
  410. },
  411. }
  412. def _print_training_setup(
  413. self,
  414. *,
  415. model_map: dict[str, nn.Module | None],
  416. loader_map: dict[str, Any | None],
  417. optimizer: Any | None = None,
  418. scheduler: Any | None = None,
  419. ) -> None:
  420. if not bool(self._logging_cfg().get("print_training_setup", True)):
  421. return
  422. summary = self._training_setup_summary(
  423. model_map=model_map,
  424. loader_map=loader_map,
  425. optimizer=optimizer,
  426. scheduler=scheduler,
  427. )
  428. print("========== TRAINING SETUP ==========")
  429. pprint.pprint(summary, sort_dicts=False, width=120)
  430. print("======== END TRAINING SETUP ========")
  431. def _gpu_memory_mb(self) -> float:
  432. if self.device.type != "cuda" or not torch.cuda.is_available():
  433. return 0.0
  434. return float(torch.cuda.max_memory_allocated(device=self.device) / (1024 ** 2))
  435. def _performance_snapshot(
  436. self,
  437. *,
  438. epoch: int,
  439. step: int,
  440. num_steps: int,
  441. data_time: float,
  442. iter_time: float,
  443. metrics: dict[str, float],
  444. prefix: str = "train",
  445. ) -> dict[str, float | int]:
  446. snapshot: dict[str, float | int] = {
  447. "epoch": epoch,
  448. "step": step,
  449. "num_steps": num_steps,
  450. "data_time": data_time,
  451. "iter_time": iter_time,
  452. "gpu_memory_mb": self._gpu_memory_mb(),
  453. }
  454. lrs = self._current_lrs(getattr(self, "optimizer", None))
  455. if lrs:
  456. snapshot["lr"] = lrs[0]
  457. for key, value in metrics.items():
  458. snapshot[f"{prefix}_{key}"] = value
  459. return snapshot
  460. def _maybe_log_step(
  461. self,
  462. *,
  463. epoch: int,
  464. step: int,
  465. num_steps: int,
  466. data_time: float,
  467. iter_time: float,
  468. metrics: dict[str, float],
  469. prefix: str = "train",
  470. ) -> None:
  471. if step % self._log_interval() != 0 and step != num_steps:
  472. return
  473. snapshot = self._performance_snapshot(
  474. epoch=epoch,
  475. step=step,
  476. num_steps=num_steps,
  477. data_time=data_time,
  478. iter_time=iter_time,
  479. metrics=metrics,
  480. prefix=prefix,
  481. )
  482. print(snapshot)
  483. log_metrics = {
  484. f"{prefix}/{key}": value
  485. for key, value in metrics.items()
  486. }
  487. log_metrics.update(
  488. {
  489. f"{prefix}/data_time": data_time,
  490. f"{prefix}/iter_time": iter_time,
  491. f"{prefix}/gpu_memory_mb": float(snapshot["gpu_memory_mb"]),
  492. }
  493. )
  494. if "lr" in snapshot:
  495. log_metrics[f"{prefix}/lr"] = float(snapshot["lr"])
  496. @staticmethod
  497. def _average_metric_sums(metric_sums: dict[str, float], steps: int) -> dict[str, float]:
  498. if steps <= 0:
  499. return {}
  500. return {key: value / steps for key, value in metric_sums.items()}
  501. def _base_checkpoint_state(self, *, epoch: int, metrics: dict[str, float] | None = None) -> dict[str, Any]:
  502. state = {
  503. "epoch": epoch,
  504. "cfg": self.cfg,
  505. "metrics": metrics or {},
  506. "grad_scaler": self.grad_scaler.state_dict(),
  507. "no_improve_epochs": self.no_improve_epochs,
  508. }
  509. return state
  510. def _finalize_epoch(
  511. self,
  512. *,
  513. epoch: int,
  514. train_metrics: dict[str, float],
  515. val_metrics: dict[str, float] | None,
  516. checkpoint_state: dict[str, Any],
  517. ) -> tuple[dict[str, Any], bool]:
  518. merged_metrics = dict(train_metrics)
  519. if val_metrics is not None:
  520. merged_metrics.update({f"val_{key}": value for key, value in val_metrics.items()})
  521. improved = False
  522. if val_metrics is not None:
  523. monitor_name = str(self._checkpoint_cfg().get("monitor", "total"))
  524. if monitor_name not in val_metrics:
  525. raise KeyError(f"Checkpoint monitor '{monitor_name}' not found in val metrics.")
  526. monitor_value = float(val_metrics[monitor_name])
  527. delta = self._early_stopping_min_delta()
  528. previous_best = self.best_metric
  529. is_better = self._is_better_metric(monitor_value)
  530. if previous_best is not None and self._best_mode() == "max":
  531. is_better = monitor_value > (previous_best + delta)
  532. elif previous_best is not None and self._best_mode() == "min":
  533. is_better = monitor_value < (previous_best - delta)
  534. if is_better:
  535. self.best_metric = monitor_value
  536. self.no_improve_epochs = 0
  537. improved = True
  538. best_state = dict(checkpoint_state)
  539. best_state.update(
  540. self._base_checkpoint_state(
  541. epoch=epoch,
  542. metrics=merged_metrics,
  543. )
  544. )
  545. best_state["best_metric"] = self.best_metric
  546. self._save_checkpoint("best.pth", best_state)
  547. else:
  548. self.no_improve_epochs += 1
  549. save_last = bool(self._checkpoint_cfg().get("save_last", True))
  550. if save_last:
  551. last_state = dict(checkpoint_state)
  552. last_state.update(self._base_checkpoint_state(epoch=epoch, metrics=merged_metrics))
  553. if self.best_metric is not None:
  554. last_state["best_metric"] = self.best_metric
  555. self._save_checkpoint("last.pth", last_state)
  556. summary = {"epoch": epoch}
  557. summary.update(train_metrics)
  558. if val_metrics is not None:
  559. summary.update({f"val_{key}": value for key, value in val_metrics.items()})
  560. if self.best_metric is not None:
  561. summary["best_metric"] = float(self.best_metric)
  562. summary["no_improve_epochs"] = self.no_improve_epochs
  563. lrs = self._current_lrs(getattr(self, "optimizer", None))
  564. if lrs:
  565. summary["lr"] = lrs[0]
  566. self._log_metrics(summary, step=epoch)
  567. should_stop = False
  568. if val_metrics is not None and self._early_stopping_enabled():
  569. should_stop = self.no_improve_epochs >= self._early_stopping_patience()
  570. summary["early_stop"] = should_stop
  571. summary["improved"] = improved
  572. return summary, should_stop
  573. @abstractmethod
  574. def build(self) -> None:
  575. """
  576. 创建模型、优化器、数据加载器等运行所需对象。
  577. """
  578. @abstractmethod
  579. def train(self) -> None:
  580. """
  581. 执行完整训练流程。
  582. """