base.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636
  1. from __future__ import annotations
  2. from abc import ABC, abstractmethod
  3. from pathlib import Path
  4. import pprint
  5. import time
  6. from typing import Any
  7. import torch
  8. import torch.nn as nn
  9. import torch.nn.functional as F
  10. from torch.amp import GradScaler
  11. from lib.data import build_dataloader
  12. from lib.tools import build_metrics, compute_metrics, reset_metrics, update_metrics
  13. try:
  14. import swanlab
  15. except ImportError:
  16. swanlab = None
  17. class BaseTrainer(ABC):
  18. """
  19. 训练器基类。
  20. 设计目标:
  21. - 统一配置入口
  22. - 统一模型/优化器/调度器创建
  23. - 不同训练流程只重写最少的方法
  24. """
  25. def __init__(self, cfg: dict[str, Any], args: Any | None = None) -> None:
  26. self.cfg = cfg
  27. self.args = args
  28. self.device = self._build_device()
  29. self.output_dir = self._build_output_dir()
  30. self.start_epoch = 0
  31. self.best_metric: float | None = None
  32. self.no_improve_epochs = 0
  33. self.swanlab_run = None
  34. self.grad_scaler = GradScaler("cuda", enabled=self._amp_enabled())
  35. def _build_device(self) -> torch.device:
  36. device_name = self.cfg.get("train", {}).get("device", "cpu")
  37. if device_name == "cuda" and not torch.cuda.is_available():
  38. device_name = "cpu"
  39. return torch.device(device_name)
  40. def _build_output_dir(self) -> Path:
  41. output_dir = self.cfg.get("checkpoint", {}).get("dir", "outputs/supervised_segmentation")
  42. path = Path(output_dir)
  43. path.mkdir(parents=True, exist_ok=True)
  44. return path
  45. def _amp_enabled(self) -> bool:
  46. return bool(self.cfg.get("train", {}).get("amp", False)) and self.device.type == "cuda"
  47. def _auto_batch_size_cfg(self) -> dict[str, Any]:
  48. cfg = self.cfg.get("train", {}).get("auto_batch_size", {})
  49. return cfg if isinstance(cfg, dict) else {}
  50. def _auto_batch_size_enabled(self) -> bool:
  51. return bool(self._auto_batch_size_cfg().get("enabled", False))
  52. def _gpu_total_memory_gb(self) -> float | None:
  53. if self.device.type != "cuda" or not torch.cuda.is_available():
  54. return None
  55. props = torch.cuda.get_device_properties(self.device)
  56. return float(props.total_memory / (1024 ** 3))
  57. def _estimate_auto_batch_size(self, *, default_batch_size: int, ssl: bool = False) -> int:
  58. cfg = self._auto_batch_size_cfg()
  59. if not cfg.get("enabled", False):
  60. return int(default_batch_size)
  61. total_gb = self._gpu_total_memory_gb()
  62. if total_gb is None:
  63. return int(default_batch_size)
  64. target_fraction = float(cfg.get("target_memory_fraction", 0.75))
  65. target_fraction = min(max(target_fraction, 0.1), 0.95)
  66. reference_gpu_gb = float(cfg.get("reference_gpu_gb", 8.0))
  67. reference_batch_size = int(cfg.get("reference_batch_size", default_batch_size))
  68. max_batch_size = int(cfg.get("max_batch_size", reference_batch_size))
  69. min_batch_size = int(cfg.get("min_batch_size", 1))
  70. memory_penalty = float(cfg.get("memory_penalty", 1.0 if not ssl else 1.35))
  71. scaled = int((reference_batch_size * total_gb * target_fraction) / max(reference_gpu_gb * 0.75 * memory_penalty, 1e-6))
  72. batch_size = max(min_batch_size, min(max_batch_size, max(default_batch_size, scaled)))
  73. return int(batch_size)
  74. def _resolve_batch_size(self, key: str, default: int, *, ssl: bool = False) -> int:
  75. train_cfg = self.cfg.get("train", {})
  76. configured = int(train_cfg.get(key, default))
  77. batch_size = self._estimate_auto_batch_size(default_batch_size=configured, ssl=ssl)
  78. if self._auto_batch_size_enabled() and batch_size != configured:
  79. print(
  80. {
  81. "message": "auto_batch_size adjusted",
  82. "key": key,
  83. "configured": configured,
  84. "resolved": batch_size,
  85. "gpu_total_gb": self._gpu_total_memory_gb(),
  86. }
  87. )
  88. return batch_size
  89. def _dataset_cfg(self) -> dict[str, Any]:
  90. return self.cfg.get("dataset", {})
  91. def _dataset_name(self) -> str:
  92. dataset_cfg = self._dataset_cfg()
  93. dataset_name = dataset_cfg.get("dataset_name") or dataset_cfg.get("name")
  94. if not dataset_name:
  95. raise ValueError("dataset.dataset_name is required.")
  96. return str(dataset_name)
  97. def _dataset_root(self) -> str:
  98. dataset_cfg = self._dataset_cfg()
  99. root = dataset_cfg.get("root")
  100. if not root:
  101. raise ValueError("dataset.root is required.")
  102. return str(root)
  103. def _image_size(self) -> tuple[int, int]:
  104. dataset_cfg = self._dataset_cfg()
  105. image_size = dataset_cfg.get("image_size")
  106. if image_size is None:
  107. raise ValueError("dataset.image_size is required.")
  108. return int(image_size[0]), int(image_size[1])
  109. def _build_resize_transform(self, *, mode: str) -> Any:
  110. height, width = self._image_size()
  111. interpolation_mode = "bilinear" if mode == "image" else "nearest"
  112. def _transform(tensor: torch.Tensor) -> torch.Tensor:
  113. resized = F.interpolate(
  114. tensor.unsqueeze(0),
  115. size=(height, width),
  116. mode=interpolation_mode,
  117. align_corners=False if interpolation_mode != "nearest" else None,
  118. )
  119. return resized.squeeze(0)
  120. return _transform
  121. def _build_segmentation_loader(
  122. self,
  123. *,
  124. split: str,
  125. batch_size: int,
  126. shuffle: bool,
  127. split_file: str | None = None,
  128. ):
  129. dataset_cfg = self._dataset_cfg()
  130. train_cfg = self.cfg.get("train", {})
  131. num_workers = max(0, int(train_cfg.get("num_workers", 0)))
  132. persistent_workers = bool(train_cfg.get("persistent_workers", False)) if num_workers > 0 else False
  133. loader = build_dataloader(
  134. dataset_name=self._dataset_name(),
  135. root=self._dataset_root(),
  136. split=split,
  137. split_file=split_file,
  138. batch_size=batch_size,
  139. shuffle=shuffle,
  140. num_workers=num_workers,
  141. image_transform=self._build_resize_transform(mode="image"),
  142. mask_transform=self._build_resize_transform(mode="mask"),
  143. pin_memory=bool(train_cfg.get("pin_memory", self.device.type == "cuda")),
  144. persistent_workers=persistent_workers,
  145. prefetch_factor=train_cfg.get("prefetch_factor") if num_workers > 0 else None,
  146. )
  147. return loader
  148. def _build_val_loader(
  149. self,
  150. *,
  151. batch_size: int,
  152. shuffle: bool = False,
  153. ):
  154. dataset_cfg = self._dataset_cfg()
  155. val_split = dataset_cfg.get("val_split", "val")
  156. if val_split is None:
  157. return None
  158. return self._build_segmentation_loader(
  159. split=str(val_split),
  160. split_file=dataset_cfg.get("val_split_file"),
  161. batch_size=batch_size,
  162. shuffle=shuffle,
  163. )
  164. def _checkpoint_cfg(self) -> dict[str, Any]:
  165. return self.cfg.get("checkpoint", {})
  166. def _logging_cfg(self) -> dict[str, Any]:
  167. return self.cfg.get("logging", {})
  168. def _validation_cfg(self) -> dict[str, Any]:
  169. return self.cfg.get("validation", {})
  170. def _checkpoint_enabled(self) -> bool:
  171. return bool(self._checkpoint_cfg().get("save", True))
  172. def _best_mode(self) -> str:
  173. return str(self._checkpoint_cfg().get("monitor_mode", "min"))
  174. def _is_better_metric(self, metric: float) -> bool:
  175. if self.best_metric is None:
  176. return True
  177. if self._best_mode() == "max":
  178. return metric > self.best_metric
  179. return metric < self.best_metric
  180. def _save_checkpoint(self, filename: str, state: dict[str, Any]) -> Path | None:
  181. if not self._checkpoint_enabled():
  182. return None
  183. path = self.output_dir / filename
  184. torch.save(state, path)
  185. return path
  186. def _resume_checkpoint_path(self) -> Path | None:
  187. resume_path = self._checkpoint_cfg().get("resume")
  188. if not resume_path:
  189. return None
  190. path = Path(str(resume_path))
  191. if not path.is_absolute():
  192. path = Path.cwd() / path
  193. return path
  194. def _maybe_resume(
  195. self,
  196. *,
  197. module_map: dict[str, Any],
  198. optimizer: Any | None = None,
  199. scheduler: Any | None = None,
  200. ) -> dict[str, Any] | None:
  201. path = self._resume_checkpoint_path()
  202. if path is None:
  203. return None
  204. if not path.exists():
  205. raise FileNotFoundError(f"Resume checkpoint not found: {path}")
  206. checkpoint = torch.load(path, map_location="cpu")
  207. strict = bool(self._checkpoint_cfg().get("resume_strict", True))
  208. for key, module in module_map.items():
  209. if module is None:
  210. continue
  211. state_dict = checkpoint.get(key)
  212. if state_dict is not None:
  213. module.load_state_dict(state_dict, strict=strict)
  214. if optimizer is not None and checkpoint.get("optimizer") is not None:
  215. optimizer.load_state_dict(checkpoint["optimizer"])
  216. if scheduler is not None and checkpoint.get("scheduler") is not None:
  217. scheduler.load_state_dict(checkpoint["scheduler"])
  218. if checkpoint.get("grad_scaler") is not None:
  219. self.grad_scaler.load_state_dict(checkpoint["grad_scaler"])
  220. if checkpoint.get("best_metric") is not None:
  221. self.best_metric = float(checkpoint["best_metric"])
  222. elif checkpoint.get("metrics") is not None:
  223. monitor_name = str(self._checkpoint_cfg().get("monitor", "total"))
  224. monitor_value = checkpoint["metrics"].get(f"val_{monitor_name}")
  225. if monitor_value is None:
  226. monitor_value = checkpoint["metrics"].get(monitor_name)
  227. if monitor_value is not None:
  228. self.best_metric = float(monitor_value)
  229. if checkpoint.get("no_improve_epochs") is not None:
  230. self.no_improve_epochs = int(checkpoint["no_improve_epochs"])
  231. if bool(self._checkpoint_cfg().get("resume_training", True)):
  232. self.start_epoch = int(checkpoint.get("epoch", -1)) + 1
  233. return checkpoint
  234. def _validation_enabled(self) -> bool:
  235. return bool(self._validation_cfg().get("enabled", True))
  236. def _validation_interval(self) -> int:
  237. return max(1, int(self._validation_cfg().get("interval", 1)))
  238. def _should_validate(self, epoch: int) -> bool:
  239. return self._validation_enabled() and ((epoch + 1) % self._validation_interval() == 0)
  240. def _metric_task_mode(self) -> str:
  241. validation_cfg = self._validation_cfg()
  242. metrics_cfg = validation_cfg.get("metrics", self.cfg.get("metrics"))
  243. if isinstance(metrics_cfg, dict):
  244. return str(metrics_cfg.get("task_mode", "binary"))
  245. return "binary"
  246. def _metric_threshold(self) -> float:
  247. validation_cfg = self._validation_cfg()
  248. threshold = validation_cfg.get("threshold", 0.5)
  249. return float(threshold)
  250. def _build_validation_metrics(self) -> dict[str, Any]:
  251. validation_cfg = self._validation_cfg()
  252. metrics_cfg = validation_cfg.get("metrics", self.cfg.get("metrics"))
  253. if metrics_cfg is None:
  254. return {}
  255. return build_metrics(metrics_cfg)
  256. def _early_stopping_enabled(self) -> bool:
  257. return bool(self._validation_cfg().get("early_stopping", False))
  258. def _early_stopping_patience(self) -> int:
  259. return max(1, int(self._validation_cfg().get("early_stopping_patience", 10)))
  260. def _early_stopping_min_delta(self) -> float:
  261. return float(self._validation_cfg().get("early_stopping_min_delta", 0.0))
  262. def _update_validation_metrics(
  263. self,
  264. metrics: dict[str, Any],
  265. *,
  266. logits: torch.Tensor,
  267. target: torch.Tensor,
  268. ) -> None:
  269. if not metrics:
  270. return
  271. update_metrics(
  272. metrics,
  273. logits,
  274. target,
  275. task_mode=self._metric_task_mode(),
  276. threshold=self._metric_threshold(),
  277. num_classes=int(self._dataset_cfg().get("num_classes", 1)),
  278. )
  279. def _compute_validation_metric_values(self, metrics: dict[str, Any]) -> dict[str, float]:
  280. if not metrics:
  281. return {}
  282. values = compute_metrics(metrics)
  283. reset_metrics(metrics)
  284. return values
  285. def _init_swanlab(self) -> None:
  286. logging_cfg = self._logging_cfg()
  287. if not bool(logging_cfg.get("use_swanlab", False)):
  288. return
  289. if swanlab is None:
  290. print("SwanLab is not installed. Logging will continue without SwanLab.")
  291. return
  292. run_name = logging_cfg.get("experiment_name") or self.output_dir.name
  293. self.swanlab_run = swanlab.init(
  294. project=logging_cfg.get("project", "X_SSL_Net"),
  295. name=run_name,
  296. config=self.cfg,
  297. mode=logging_cfg.get("swanlab_mode"),
  298. )
  299. def _log_metrics(self, metrics: dict[str, float], *, step: int) -> None:
  300. if self.swanlab_run is None:
  301. return
  302. swanlab.log(metrics, step=step)
  303. def _close_loggers(self) -> None:
  304. if self.swanlab_run is not None:
  305. swanlab.finish()
  306. self.swanlab_run = None
  307. def _log_interval(self) -> int:
  308. return max(1, int(self._logging_cfg().get("log_interval", 20)))
  309. def _grad_clip_cfg(self) -> dict[str, Any]:
  310. cfg = self.cfg.get("train", {}).get("grad_clip", {})
  311. return cfg if isinstance(cfg, dict) else {}
  312. def _grad_clip_enabled(self) -> bool:
  313. return bool(self._grad_clip_cfg().get("enabled", False))
  314. def _clip_gradients(self, module: nn.Module | None) -> float | None:
  315. if module is None or not self._grad_clip_enabled():
  316. return None
  317. cfg = self._grad_clip_cfg()
  318. max_norm = float(cfg.get("max_norm", 1.0))
  319. norm_type = float(cfg.get("norm_type", 2.0))
  320. params = [param for param in module.parameters() if param.requires_grad and param.grad is not None]
  321. if not params:
  322. return None
  323. total_norm = torch.nn.utils.clip_grad_norm_(params, max_norm=max_norm, norm_type=norm_type)
  324. return float(total_norm.detach().cpu() if isinstance(total_norm, torch.Tensor) else total_norm)
  325. def _current_lrs(self, optimizer: Any | None) -> list[float]:
  326. if optimizer is None:
  327. return []
  328. return [float(group.get("lr", 0.0)) for group in optimizer.param_groups]
  329. @staticmethod
  330. def _count_parameters(module: nn.Module | None) -> dict[str, int]:
  331. if module is None:
  332. return {"total": 0, "trainable": 0}
  333. total = sum(param.numel() for param in module.parameters())
  334. trainable = sum(param.numel() for param in module.parameters() if param.requires_grad)
  335. return {"total": int(total), "trainable": int(trainable)}
  336. @staticmethod
  337. def _loader_summary(loader: Any | None) -> dict[str, Any] | None:
  338. if loader is None:
  339. return None
  340. dataset = getattr(loader, "dataset", None)
  341. return {
  342. "dataset_size": len(dataset) if dataset is not None else None,
  343. "num_batches": len(loader),
  344. "batch_size": getattr(loader, "batch_size", None),
  345. "num_workers": getattr(loader, "num_workers", None),
  346. "pin_memory": getattr(loader, "pin_memory", None),
  347. "persistent_workers": getattr(loader, "persistent_workers", None),
  348. "prefetch_factor": getattr(loader, "prefetch_factor", None),
  349. "drop_last": getattr(loader, "drop_last", None),
  350. }
  351. def _training_setup_summary(
  352. self,
  353. *,
  354. model_map: dict[str, nn.Module | None],
  355. loader_map: dict[str, Any | None],
  356. optimizer: Any | None = None,
  357. scheduler: Any | None = None,
  358. ) -> dict[str, Any]:
  359. return {
  360. "trainer": self.cfg.get("trainer", {}).get("name"),
  361. "device": str(self.device),
  362. "amp_enabled": self._amp_enabled(),
  363. "output_dir": str(self.output_dir),
  364. "start_epoch": self.start_epoch,
  365. "train": self.cfg.get("train", {}),
  366. "dataset": self.cfg.get("dataset", {}),
  367. "model": self.cfg.get("model", {}),
  368. "optimizer": self.cfg.get("optimizer", {}),
  369. "scheduler": self.cfg.get("scheduler"),
  370. "current_lrs": self._current_lrs(optimizer),
  371. "validation": self.cfg.get("validation", {}),
  372. "checkpoint": self.cfg.get("checkpoint", {}),
  373. "logging": self.cfg.get("logging", {}),
  374. "model_parameters": {
  375. name: self._count_parameters(module)
  376. for name, module in model_map.items()
  377. },
  378. "loaders": {
  379. name: self._loader_summary(loader)
  380. for name, loader in loader_map.items()
  381. },
  382. "cuda": {
  383. "available": torch.cuda.is_available(),
  384. "device_name": torch.cuda.get_device_name(self.device) if self.device.type == "cuda" else None,
  385. "device_count": torch.cuda.device_count(),
  386. },
  387. }
  388. def _print_training_setup(
  389. self,
  390. *,
  391. model_map: dict[str, nn.Module | None],
  392. loader_map: dict[str, Any | None],
  393. optimizer: Any | None = None,
  394. scheduler: Any | None = None,
  395. ) -> None:
  396. if not bool(self._logging_cfg().get("print_training_setup", True)):
  397. return
  398. summary = self._training_setup_summary(
  399. model_map=model_map,
  400. loader_map=loader_map,
  401. optimizer=optimizer,
  402. scheduler=scheduler,
  403. )
  404. print("========== TRAINING SETUP ==========")
  405. pprint.pprint(summary, sort_dicts=False, width=120)
  406. print("======== END TRAINING SETUP ========")
  407. def _gpu_memory_mb(self) -> float:
  408. if self.device.type != "cuda" or not torch.cuda.is_available():
  409. return 0.0
  410. return float(torch.cuda.max_memory_allocated(device=self.device) / (1024 ** 2))
  411. def _performance_snapshot(
  412. self,
  413. *,
  414. epoch: int,
  415. step: int,
  416. num_steps: int,
  417. data_time: float,
  418. iter_time: float,
  419. metrics: dict[str, float],
  420. prefix: str = "train",
  421. ) -> dict[str, float | int]:
  422. snapshot: dict[str, float | int] = {
  423. "epoch": epoch,
  424. "step": step,
  425. "num_steps": num_steps,
  426. "data_time": data_time,
  427. "iter_time": iter_time,
  428. "gpu_memory_mb": self._gpu_memory_mb(),
  429. }
  430. lrs = self._current_lrs(getattr(self, "optimizer", None))
  431. if lrs:
  432. snapshot["lr"] = lrs[0]
  433. for key, value in metrics.items():
  434. snapshot[f"{prefix}_{key}"] = value
  435. return snapshot
  436. def _maybe_log_step(
  437. self,
  438. *,
  439. epoch: int,
  440. step: int,
  441. num_steps: int,
  442. data_time: float,
  443. iter_time: float,
  444. metrics: dict[str, float],
  445. prefix: str = "train",
  446. ) -> None:
  447. if step % self._log_interval() != 0 and step != num_steps:
  448. return
  449. snapshot = self._performance_snapshot(
  450. epoch=epoch,
  451. step=step,
  452. num_steps=num_steps,
  453. data_time=data_time,
  454. iter_time=iter_time,
  455. metrics=metrics,
  456. prefix=prefix,
  457. )
  458. print(snapshot)
  459. log_metrics = {
  460. f"{prefix}/{key}": value
  461. for key, value in metrics.items()
  462. }
  463. log_metrics.update(
  464. {
  465. f"{prefix}/data_time": data_time,
  466. f"{prefix}/iter_time": iter_time,
  467. f"{prefix}/gpu_memory_mb": float(snapshot["gpu_memory_mb"]),
  468. }
  469. )
  470. if "lr" in snapshot:
  471. log_metrics[f"{prefix}/lr"] = float(snapshot["lr"])
  472. self._log_metrics(log_metrics, step=epoch * max(1, num_steps) + step)
  473. @staticmethod
  474. def _average_metric_sums(metric_sums: dict[str, float], steps: int) -> dict[str, float]:
  475. if steps <= 0:
  476. return {}
  477. return {key: value / steps for key, value in metric_sums.items()}
  478. def _base_checkpoint_state(self, *, epoch: int, metrics: dict[str, float] | None = None) -> dict[str, Any]:
  479. state = {
  480. "epoch": epoch,
  481. "cfg": self.cfg,
  482. "metrics": metrics or {},
  483. "grad_scaler": self.grad_scaler.state_dict(),
  484. "no_improve_epochs": self.no_improve_epochs,
  485. }
  486. return state
  487. def _finalize_epoch(
  488. self,
  489. *,
  490. epoch: int,
  491. train_metrics: dict[str, float],
  492. val_metrics: dict[str, float] | None,
  493. checkpoint_state: dict[str, Any],
  494. ) -> tuple[dict[str, Any], bool]:
  495. merged_metrics = dict(train_metrics)
  496. if val_metrics is not None:
  497. merged_metrics.update({f"val_{key}": value for key, value in val_metrics.items()})
  498. improved = False
  499. if val_metrics is not None:
  500. monitor_name = str(self._checkpoint_cfg().get("monitor", "total"))
  501. if monitor_name not in val_metrics:
  502. raise KeyError(f"Checkpoint monitor '{monitor_name}' not found in val metrics.")
  503. monitor_value = float(val_metrics[monitor_name])
  504. delta = self._early_stopping_min_delta()
  505. previous_best = self.best_metric
  506. is_better = self._is_better_metric(monitor_value)
  507. if previous_best is not None and self._best_mode() == "max":
  508. is_better = monitor_value > (previous_best + delta)
  509. elif previous_best is not None and self._best_mode() == "min":
  510. is_better = monitor_value < (previous_best - delta)
  511. if is_better:
  512. self.best_metric = monitor_value
  513. self.no_improve_epochs = 0
  514. improved = True
  515. best_state = dict(checkpoint_state)
  516. best_state.update(
  517. self._base_checkpoint_state(
  518. epoch=epoch,
  519. metrics=merged_metrics,
  520. )
  521. )
  522. best_state["best_metric"] = self.best_metric
  523. self._save_checkpoint("best.pth", best_state)
  524. else:
  525. self.no_improve_epochs += 1
  526. save_last = bool(self._checkpoint_cfg().get("save_last", True))
  527. if save_last:
  528. last_state = dict(checkpoint_state)
  529. last_state.update(self._base_checkpoint_state(epoch=epoch, metrics=merged_metrics))
  530. if self.best_metric is not None:
  531. last_state["best_metric"] = self.best_metric
  532. self._save_checkpoint("last.pth", last_state)
  533. summary = {"epoch": epoch}
  534. summary.update(train_metrics)
  535. if val_metrics is not None:
  536. summary.update({f"val_{key}": value for key, value in val_metrics.items()})
  537. if self.best_metric is not None:
  538. summary["best_metric"] = float(self.best_metric)
  539. summary["no_improve_epochs"] = self.no_improve_epochs
  540. lrs = self._current_lrs(getattr(self, "optimizer", None))
  541. if lrs:
  542. summary["lr"] = lrs[0]
  543. self._log_metrics(summary, step=epoch)
  544. should_stop = False
  545. if val_metrics is not None and self._early_stopping_enabled():
  546. should_stop = self.no_improve_epochs >= self._early_stopping_patience()
  547. summary["early_stop"] = should_stop
  548. summary["improved"] = improved
  549. return summary, should_stop
  550. @abstractmethod
  551. def build(self) -> None:
  552. """
  553. 创建模型、优化器、数据加载器等运行所需对象。
  554. """
  555. @abstractmethod
  556. def train(self) -> None:
  557. """
  558. 执行完整训练流程。
  559. """