# optim.py (4.2 KB)
# Optimizer and learning-rate scheduler factories built from config dicts.
  1. from __future__ import annotations
  2. from typing import Any
  3. from torch import nn
  4. from torch.optim import Adam, AdamW, Optimizer
  5. from torch.optim.lr_scheduler import (
  6. ConstantLR,
  7. CosineAnnealingLR,
  8. LinearLR,
  9. MultiStepLR,
  10. PolynomialLR,
  11. SequentialLR,
  12. )
  # Optimizer classes selectable through the "name" field of an optimizer
  # config; build_optimizer lower-cases that name before looking it up here.
  OPTIMIZER_REGISTRY: dict[str, type[Optimizer]] = {"adam": Adam, "adamw": AdamW}
  # Scheduler classes selectable by name. _build_single_scheduler lower-cases
  # the requested name and resolves it here, for the main schedule as well as
  # for the optional warmup schedule.
  SCHEDULER_REGISTRY: dict[str, type] = {
      "linear": LinearLR,
      "cosine": CosineAnnealingLR,
      "poly": PolynomialLR,
      "multistep": MultiStepLR,
      "constant": ConstantLR,
  }
  # Baseline configuration. The "optimizer" section has the shape accepted by
  # build_optimizer and the "scheduler" section the shape accepted by
  # build_scheduler.
  DEFAULT_OPTIM_CONFIG: dict[str, dict[str, Any]] = {
      "optimizer": {
          "name": "adamw",
          "lr": 5e-5,
          "weight_decay": 0.05,
          "betas": (0.9, 0.999),
      },
      "scheduler": {
          "name": "cosine",
          # Warmup schedule; build_scheduler also uses its "total_iters" as the
          # SequentialLR milestone at which the main schedule takes over.
          "warmup": {
              "name": "linear",
              "params": {"start_factor": 0.1, "total_iters": 10},
          },
          # Keyword arguments for the main scheduler's constructor.
          "params": {"T_max": 90, "eta_min": 1e-6},
      },
  }
  46. def _trainable_parameters(model: nn.Module) -> list[nn.Parameter]:
  47. return [param for param in model.parameters() if param.requires_grad]
  def build_optimizer(model: nn.Module, config: dict[str, Any]) -> Optimizer:
      """Instantiate the optimizer described by ``config`` for ``model``.

      Args:
          model: Module whose parameters with ``requires_grad`` set are handed
              to the optimizer.
          config: Non-empty dict. ``"name"`` (optional, default ``"adamw"``)
              selects a class from ``OPTIMIZER_REGISTRY`` case-insensitively,
              ``"lr"`` is required, and every entry other than ``"name"`` is
              forwarded to the optimizer constructor as a keyword argument.

      Returns:
          The constructed optimizer.

      Raises:
          ValueError: If ``config`` is not a non-empty dict, ``"name"`` is not
              a string or not registered, or ``"lr"`` is missing.
      """
      if not (isinstance(config, dict) and config):
          raise ValueError("Optimizer config must be a non-empty dict.")

      name = config.get("name", "adamw")
      if not isinstance(name, str):
          raise ValueError("Optimizer config field 'name' must be a string.")

      optimizer_cls = OPTIMIZER_REGISTRY.get(name.lower())
      if optimizer_cls is None:
          raise ValueError(
              f"Unsupported optimizer '{name}'. Expected one of: {', '.join(OPTIMIZER_REGISTRY)}."
          )

      if "lr" not in config:
          raise ValueError("Optimizer config must provide 'lr'.")

      # Everything except the selector key is a constructor argument.
      constructor_kwargs = dict(config)
      constructor_kwargs.pop("name", None)
      return optimizer_cls(_trainable_parameters(model), **constructor_kwargs)
  def _build_single_scheduler(optimizer: Optimizer, name: str, params: dict[str, Any] | None):
      """Instantiate one registered scheduler for ``optimizer``.

      Args:
          optimizer: Passed as the first positional argument to the scheduler
              class.
          name: Key of ``SCHEDULER_REGISTRY``, matched case-insensitively.
          params: Keyword arguments for the scheduler constructor, or ``None``
              to call it with the optimizer only.

      Returns:
          The constructed scheduler instance.

      Raises:
          ValueError: If ``name`` is not registered, or ``params`` is neither
              ``None`` nor a dict.
      """
      scheduler_cls = SCHEDULER_REGISTRY.get(name.lower())
      if scheduler_cls is None:
          raise ValueError(
              f"Unsupported scheduler '{name}'. Expected one of: {', '.join(SCHEDULER_REGISTRY)}."
          )
      if params is None:
          return scheduler_cls(optimizer)
      # Validate the type before using it. Converting with dict(params) first
      # made this check unreachable: most non-dict values failed inside dict()
      # with an unrelated TypeError/ValueError, and an iterable of pairs was
      # silently accepted.
      if not isinstance(params, dict):
          raise ValueError("Scheduler config field 'params' must be a dict if provided.")
      # ** unpacking builds a fresh kwargs mapping, so the caller's dict is
      # not shared with the scheduler.
      return scheduler_cls(optimizer, **params)
  def build_scheduler(optimizer: Optimizer, config: dict[str, Any] | None):
      """Build a PyTorch scheduler from a yaml-style config dict.

      Args:
          optimizer: Optimizer the scheduler(s) are attached to.
          config: ``None`` for no scheduler, otherwise a non-empty dict with:
              ``"name"`` (optional, default ``"cosine"``), ``"params"``
              (optional constructor keyword arguments) and ``"warmup"``
              (optional dict with ``"name"``, default ``"linear"``, and a
              required ``"params"`` dict containing a positive integer
              ``"total_iters"``).

      Returns:
          ``None`` when ``config`` is ``None``; the main scheduler when no
          warmup is configured; otherwise a ``SequentialLR`` that runs the
          warmup scheduler and switches to the main scheduler at
          ``total_iters``.

      Raises:
          ValueError: If ``config`` or its ``"warmup"`` section is malformed,
              or a scheduler name is not registered.
      """
      if config is None:
          return None
      if not isinstance(config, dict) or not config:
          raise ValueError("Scheduler config must be a non-empty dict.")
      name = config.get("name", "cosine")
      if not isinstance(name, str):
          raise ValueError("Scheduler config field 'name' must be a string.")

      # Validate the warmup section before any scheduler is constructed, so a
      # malformed warmup config is rejected without first attaching the main
      # scheduler to the optimizer.
      warmup = config.get("warmup")
      parsed_warmup = None if warmup is None else _parse_warmup_config(warmup)

      main_scheduler = _build_single_scheduler(optimizer, name, config.get("params"))
      if parsed_warmup is None:
          return main_scheduler

      warmup_name, warmup_params, total_iters = parsed_warmup
      warmup_scheduler = _build_single_scheduler(optimizer, warmup_name, warmup_params)
      return SequentialLR(
          optimizer,
          schedulers=[warmup_scheduler, main_scheduler],
          milestones=[total_iters],
      )


  def _parse_warmup_config(warmup: Any) -> tuple[str, dict[str, Any], int]:
      """Validate a warmup section and return ``(name, params, total_iters)``."""
      if not isinstance(warmup, dict) or not warmup:
          raise ValueError(
              "Scheduler config field 'warmup' must be a non-empty dict if provided."
          )
      warmup_name = warmup.get("name", "linear")
      if not isinstance(warmup_name, str):
          raise ValueError("Warmup scheduler field 'name' must be a string.")
      warmup_params = warmup.get("params")
      if warmup_params is None:
          raise ValueError("Warmup scheduler must provide a 'params' dict.")
      if not isinstance(warmup_params, dict):
          raise ValueError("Warmup scheduler field 'params' must be a dict.")
      total_iters = warmup_params.get("total_iters")
      # bool is a subclass of int, so True would otherwise pass the integer
      # check and end up as the SequentialLR milestone.
      if (
          isinstance(total_iters, bool)
          or not isinstance(total_iters, int)
          or total_iters <= 0
      ):
          raise ValueError("Warmup scheduler requires integer 'total_iters' > 0.")
      return warmup_name, warmup_params, total_iters
  # Public names of this module; the underscore-prefixed helpers are not
  # exported.
  __all__ = [
      "DEFAULT_OPTIM_CONFIG",
      "OPTIMIZER_REGISTRY",
      "SCHEDULER_REGISTRY",
      "build_optimizer",
      "build_scheduler",
  ]