"""Build PyTorch optimizers and learning-rate schedulers from yaml-style config dicts.

Optimizers and schedulers are looked up by case-insensitive name in
``OPTIMIZER_REGISTRY`` / ``SCHEDULER_REGISTRY``; the remaining config keys are
forwarded as keyword arguments to the selected PyTorch class.
"""

from __future__ import annotations

from typing import Any

from torch import nn
from torch.optim import Adam, AdamW, Optimizer
from torch.optim.lr_scheduler import (
    ConstantLR,
    CosineAnnealingLR,
    LinearLR,
    MultiStepLR,
    PolynomialLR,
    SequentialLR,
)

# Optimizer classes addressable by (lower-case) config name.
OPTIMIZER_REGISTRY = {
    "adam": Adam,
    "adamw": AdamW,
}

# Scheduler classes addressable by (lower-case) config name.
SCHEDULER_REGISTRY = {
    "linear": LinearLR,
    "cosine": CosineAnnealingLR,
    "poly": PolynomialLR,
    "multistep": MultiStepLR,
    "constant": ConstantLR,
}

# Reference config: AdamW, then a 10-step LinearLR warmup starting at 0.1x the
# base lr, followed by CosineAnnealingLR (T_max=90, eta_min=1e-6).
# "optimizer" is a valid build_optimizer config and "scheduler" a valid
# build_scheduler config.
DEFAULT_OPTIM_CONFIG = {
    "optimizer": {
        "name": "adamw",
        "lr": 5e-5,
        "weight_decay": 0.05,
        "betas": (0.9, 0.999),
    },
    "scheduler": {
        "name": "cosine",
        "warmup": {
            "name": "linear",
            "params": {
                "start_factor": 0.1,
                "total_iters": 10,
            },
        },
        "params": {
            "T_max": 90,
            "eta_min": 1e-6,
        },
    },
}


def _trainable_parameters(model: nn.Module) -> list[nn.Parameter]:
    """Return the parameters of ``model`` whose ``requires_grad`` flag is set."""
    return [param for param in model.parameters() if param.requires_grad]


def build_optimizer(model: nn.Module, config: dict[str, Any]) -> Optimizer:
    """Build a PyTorch optimizer from a yaml-style config dict.

    Args:
        model: Module whose parameters with ``requires_grad=True`` are
            optimized; frozen parameters are left out.
        config: Must contain ``lr``. ``name`` (default ``"adamw"``) selects a
            class from ``OPTIMIZER_REGISTRY`` case-insensitively; every other
            key is passed to that class as a keyword argument.

    Returns:
        The constructed optimizer. ``config`` itself is not modified.

    Raises:
        ValueError: if ``config`` is not a non-empty dict, ``name`` is not a
            string or not registered, or ``lr`` is missing.
    """
    if not isinstance(config, dict) or not config:
        raise ValueError("Optimizer config must be a non-empty dict.")

    name = config.get("name", "adamw")
    if not isinstance(name, str):
        raise ValueError("Optimizer config field 'name' must be a string.")

    optimizer_cls = OPTIMIZER_REGISTRY.get(name.lower())
    if optimizer_cls is None:
        raise ValueError(
            f"Unsupported optimizer '{name}'. Expected one of: {', '.join(OPTIMIZER_REGISTRY)}."
        )

    if "lr" not in config:
        raise ValueError("Optimizer config must provide 'lr'.")

    # Everything except the registry key is forwarded to the constructor.
    kwargs = {key: value for key, value in config.items() if key != "name"}
    return optimizer_cls(_trainable_parameters(model), **kwargs)


def _resolve_scheduler(name: str, params: Any) -> tuple[type, dict[str, Any]]:
    """Validate a scheduler name/params pair without touching any optimizer.

    Returns the registered scheduler class and a shallow copy of ``params``
    (``{}`` when ``params`` is None).

    Raises:
        ValueError: if ``name`` is not registered or ``params`` is neither
            None nor a dict.
    """
    scheduler_cls = SCHEDULER_REGISTRY.get(name.lower())
    if scheduler_cls is None:
        raise ValueError(
            f"Unsupported scheduler '{name}'. Expected one of: {', '.join(SCHEDULER_REGISTRY)}."
        )
    if params is None:
        return scheduler_cls, {}
    # Check the type *before* copying: dict() would otherwise silently accept a
    # list of pairs, or fail with an unrelated error for other types.
    if not isinstance(params, dict):
        raise ValueError("Scheduler config field 'params' must be a dict if provided.")
    return scheduler_cls, dict(params)


def _build_single_scheduler(optimizer: Optimizer, name: str, params: dict[str, Any] | None):
    """Validate ``name``/``params`` and construct that scheduler on ``optimizer``."""
    scheduler_cls, scheduler_params = _resolve_scheduler(name, params)
    return scheduler_cls(optimizer, **scheduler_params)


def build_scheduler(optimizer: Optimizer, config: dict[str, Any] | None):
    """Build a PyTorch scheduler from a yaml-style config dict.

    Args:
        optimizer: Optimizer the scheduler(s) are attached to.
        config: ``None`` disables scheduling. Otherwise ``name`` (default
            ``"cosine"``) selects from ``SCHEDULER_REGISTRY`` and the optional
            ``params`` dict is passed as keyword arguments. An optional
            ``warmup`` dict (``name`` defaults to ``"linear"``; ``params`` is
            required and must contain an int ``total_iters`` > 0) is run first.

    Returns:
        ``None`` when ``config`` is None. Otherwise, without warmup, the main
        scheduler. With warmup, a ``SequentialLR`` that runs the warmup
        scheduler and then switches to the main one at milestone
        ``total_iters``.

    Raises:
        ValueError: on any malformed field. All validation happens before any
            scheduler is constructed, so a rejected config leaves ``optimizer``
            untouched.
    """
    if config is None:
        return None

    if not isinstance(config, dict) or not config:
        raise ValueError("Scheduler config must be a non-empty dict.")

    name = config.get("name", "cosine")
    if not isinstance(name, str):
        raise ValueError("Scheduler config field 'name' must be a string.")
    main_cls, main_params = _resolve_scheduler(name, config.get("params"))

    warmup = config.get("warmup")
    if warmup is None:
        return main_cls(optimizer, **main_params)

    if not isinstance(warmup, dict) or not warmup:
        raise ValueError("Scheduler config field 'warmup' must be a non-empty dict if provided.")

    warmup_name = warmup.get("name", "linear")
    if not isinstance(warmup_name, str):
        raise ValueError("Warmup scheduler field 'name' must be a string.")

    warmup_params = warmup.get("params")
    if warmup_params is None:
        raise ValueError("Warmup scheduler must provide a 'params' dict.")
    if not isinstance(warmup_params, dict):
        raise ValueError("Warmup scheduler field 'params' must be a dict.")

    total_iters = warmup_params.get("total_iters")
    # bool is a subclass of int; reject it explicitly.
    if isinstance(total_iters, bool) or not isinstance(total_iters, int) or total_iters <= 0:
        raise ValueError("Warmup scheduler requires integer 'total_iters' > 0.")
    warmup_cls, warmup_kwargs = _resolve_scheduler(warmup_name, warmup_params)

    # Only construct schedulers once the whole config is known to be valid:
    # constructing a PyTorch scheduler records 'initial_lr' in the optimizer's
    # param groups and may change the current lr. Construction order (main,
    # then warmup) is unchanged from before.
    main_scheduler = main_cls(optimizer, **main_params)
    warmup_scheduler = warmup_cls(optimizer, **warmup_kwargs)
    return SequentialLR(
        optimizer,
        schedulers=[warmup_scheduler, main_scheduler],
        milestones=[total_iters],
    )


__all__ = [
    "DEFAULT_OPTIM_CONFIG",
    "OPTIMIZER_REGISTRY",
    "SCHEDULER_REGISTRY",
    "build_optimizer",
    "build_scheduler",
]