| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146 |
- from __future__ import annotations
- from typing import Any
- from torch import nn
- from torch.optim import Adam, AdamW, Optimizer
- from torch.optim.lr_scheduler import (
- ConstantLR,
- CosineAnnealingLR,
- LinearLR,
- MultiStepLR,
- PolynomialLR,
- SequentialLR,
- )
# Lower-case optimizer names accepted in a config's "name" field, mapped to the
# torch.optim class that build_optimizer instantiates.
OPTIMIZER_REGISTRY = dict(
    adam=Adam,
    adamw=AdamW,
)
# Lower-case scheduler names accepted in a scheduler (or warmup) config's
# "name" field, mapped to the torch.optim.lr_scheduler class to construct.
SCHEDULER_REGISTRY = dict(
    linear=LinearLR,
    cosine=CosineAnnealingLR,
    poly=PolynomialLR,
    multistep=MultiStepLR,
    constant=ConstantLR,
)
# Example/default configuration. The "optimizer" section is shaped for
# build_optimizer and the "scheduler" section for build_scheduler: a linear
# warmup over 10 iterations followed by cosine annealing with T_max=90.
DEFAULT_OPTIM_CONFIG = dict(
    optimizer=dict(
        name="adamw",
        lr=5e-5,
        weight_decay=0.05,
        betas=(0.9, 0.999),
    ),
    scheduler=dict(
        name="cosine",
        warmup=dict(
            name="linear",
            params=dict(start_factor=0.1, total_iters=10),
        ),
        params=dict(T_max=90, eta_min=1e-6),
    ),
)
- def _trainable_parameters(model: nn.Module) -> list[nn.Parameter]:
- return [param for param in model.parameters() if param.requires_grad]
def build_optimizer(model: nn.Module, config: dict[str, Any]) -> Optimizer:
    """Build a PyTorch optimizer from a yaml-style config dict.

    Args:
        model: Module whose parameters with ``requires_grad=True`` are handed
            to the optimizer.
        config: Dict with an optional ``"name"`` (a key of
            ``OPTIMIZER_REGISTRY``, matched case-insensitively, default
            ``"adamw"``), a required ``"lr"``, and any further keyword
            arguments accepted by the chosen optimizer class (for example
            ``weight_decay`` or ``betas``). The dict itself is not modified.

    Returns:
        The constructed optimizer.

    Raises:
        ValueError: If ``config`` is not a non-empty dict, ``name`` is not a
            string or not registered, ``lr`` is missing, one of ``lr`` /
            ``weight_decay`` / ``eps`` is a string that is not a valid float,
            or ``model`` has no trainable parameters.
    """
    if not isinstance(config, dict) or not config:
        raise ValueError("Optimizer config must be a non-empty dict.")
    name = config.get("name", "adamw")
    if not isinstance(name, str):
        raise ValueError("Optimizer config field 'name' must be a string.")
    optimizer_cls = OPTIMIZER_REGISTRY.get(name.lower())
    if optimizer_cls is None:
        raise ValueError(
            f"Unsupported optimizer '{name}'. Expected one of: {', '.join(OPTIMIZER_REGISTRY)}."
        )
    if "lr" not in config:
        raise ValueError("Optimizer config must provide 'lr'.")

    # Everything except the registry key is forwarded as a constructor kwarg.
    kwargs = {key: value for key, value in config.items() if key != "name"}

    # PyYAML follows YAML 1.1, whose float syntax requires a '.', so values like
    # `lr: 5e-5` or `eps: 1e-8` are loaded as strings. torch would then fail
    # with an opaque TypeError when comparing them to 0.0, so coerce them here.
    for key in ("lr", "weight_decay", "eps"):
        value = kwargs.get(key)
        if isinstance(value, str):
            try:
                kwargs[key] = float(value)
            except ValueError:
                raise ValueError(
                    f"Optimizer config field '{key}' must be a number, got {value!r}."
                ) from None

    params = _trainable_parameters(model)
    if not params:
        # torch would also reject an empty list; this message explains the cause.
        raise ValueError("Model has no trainable parameters (all have requires_grad=False).")
    return optimizer_cls(params, **kwargs)
def _build_single_scheduler(optimizer: Optimizer, name: str, params: dict[str, Any] | None):
    """Instantiate one registered LR scheduler for *optimizer*.

    Args:
        optimizer: Optimizer the scheduler will drive.
        name: Key of ``SCHEDULER_REGISTRY``, matched case-insensitively.
        params: Keyword arguments for the scheduler class, or ``None`` for
            none. A shallow copy is passed on, so the caller's mapping is not
            modified.

    Returns:
        The constructed scheduler instance.

    Raises:
        ValueError: If ``name`` is not registered, or ``params`` is neither
            ``None`` nor a mapping.
    """
    from collections.abc import Mapping

    scheduler_cls = SCHEDULER_REGISTRY.get(name.lower())
    if scheduler_cls is None:
        raise ValueError(
            f"Unsupported scheduler '{name}'. Expected one of: {', '.join(SCHEDULER_REGISTRY)}."
        )
    # The type check must run *before* copying: dict(params) always yields a
    # dict (or raises an unrelated error), so checking the copy can never fail.
    if params is None:
        scheduler_params = {}
    elif isinstance(params, Mapping):
        scheduler_params = dict(params)
    else:
        raise ValueError("Scheduler config field 'params' must be a dict if provided.")
    return scheduler_cls(optimizer, **scheduler_params)
def _parse_warmup(warmup: Any) -> tuple[str, dict[str, Any], int]:
    """Validate a warmup sub-config and return ``(name, params, total_iters)``."""
    if not isinstance(warmup, dict) or not warmup:
        raise ValueError("Scheduler config field 'warmup' must be a non-empty dict if provided.")
    warmup_name = warmup.get("name", "linear")
    if not isinstance(warmup_name, str):
        raise ValueError("Warmup scheduler field 'name' must be a string.")
    if warmup_name.lower() not in SCHEDULER_REGISTRY:
        raise ValueError(
            f"Unsupported scheduler '{warmup_name}'. Expected one of: {', '.join(SCHEDULER_REGISTRY)}."
        )
    warmup_params = warmup.get("params")
    if warmup_params is None:
        raise ValueError("Warmup scheduler must provide a 'params' dict.")
    if not isinstance(warmup_params, dict):
        raise ValueError("Warmup scheduler field 'params' must be a dict.")
    total_iters = warmup_params.get("total_iters")
    # bool is a subclass of int, so `total_iters: true` must be rejected explicitly.
    if isinstance(total_iters, bool) or not isinstance(total_iters, int) or total_iters <= 0:
        raise ValueError("Warmup scheduler requires integer 'total_iters' > 0.")
    return warmup_name, warmup_params, total_iters


def build_scheduler(optimizer: Optimizer, config: dict[str, Any] | None):
    """Build a PyTorch scheduler from a yaml-style config dict.

    Args:
        optimizer: Optimizer the scheduler(s) will drive.
        config: ``None`` (no scheduler), or a dict with an optional
            ``"name"`` (a ``SCHEDULER_REGISTRY`` key, default ``"cosine"``),
            optional ``"params"`` (kwargs for that scheduler), and an
            optional ``"warmup"`` dict holding ``"name"`` (default
            ``"linear"``) and a required ``"params"`` dict that must contain
            an integer ``"total_iters" > 0``.

    Returns:
        ``None`` when ``config`` is ``None``; the main scheduler when no
        warmup is configured; otherwise a ``SequentialLR`` that runs the
        warmup for ``total_iters`` steps and then switches to the main
        scheduler.

    Raises:
        ValueError: If the config or its warmup section is malformed, or a
            scheduler name is not registered.
    """
    if config is None:
        return None
    if not isinstance(config, dict) or not config:
        raise ValueError("Scheduler config must be a non-empty dict.")
    name = config.get("name", "cosine")
    if not isinstance(name, str):
        raise ValueError("Scheduler config field 'name' must be a string.")

    warmup = config.get("warmup")
    # Validate the warmup section before constructing anything. torch schedulers
    # modify the optimizer's param groups when created (they record
    # 'initial_lr' and take an initial step), so failing afterwards would leave
    # the optimizer altered.
    parsed_warmup = None if warmup is None else _parse_warmup(warmup)

    main_scheduler = _build_single_scheduler(optimizer, name, config.get("params"))
    if parsed_warmup is None:
        return main_scheduler

    warmup_name, warmup_params, total_iters = parsed_warmup
    # The main scheduler is built first and the warmup last (as before), so the
    # warmup's construction-time step is the last one applied to the lr.
    warmup_scheduler = _build_single_scheduler(optimizer, warmup_name, warmup_params)
    return SequentialLR(
        optimizer,
        schedulers=[warmup_scheduler, main_scheduler],
        milestones=[total_iters],
    )
# Public API: the default config and registries, followed by the two builders.
__all__ = ["DEFAULT_OPTIM_CONFIG", "OPTIMIZER_REGISTRY", "SCHEDULER_REGISTRY",
           "build_optimizer", "build_scheduler"]
|