# cosine_lr.py
""" Cosine Scheduler

Cosine LR schedule with warmup, cycle/restarts, noise.

Hacked together by / Copyright 2020 Ross Wightman
"""
  5. import logging
  6. import math
  7. import numpy as np
  8. import torch
  9. from timm.scheduler.scheduler import Scheduler
  10. _logger = logging.getLogger(__name__)
  11. class CosineLRScheduler(Scheduler):
  12. """
  13. Cosine decay with restarts.
  14. This is described in the paper https://arxiv.org/abs/1608.03983.
  15. Inspiration from
  16. https://github.com/allenai/allennlp/blob/master/allennlp/training/learning_rate_schedulers/cosine.py
  17. """
  18. def __init__(self,
  19. optimizer: torch.optim.Optimizer,
  20. t_initial: int,
  21. t_mul: float = 1.,
  22. lr_min: float = 0.,
  23. decay_rate: float = 1.,
  24. warmup_t=0,
  25. warmup_lr_init=0,
  26. warmup_prefix=False,
  27. cycle_limit=0,
  28. t_in_epochs=True,
  29. noise_range_t=None,
  30. noise_pct=0.67,
  31. noise_std=1.0,
  32. noise_seed=42,
  33. initialize=True) -> None:
  34. super().__init__(
  35. optimizer, param_group_field="lr",
  36. noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
  37. initialize=initialize)
  38. assert t_initial > 0
  39. assert lr_min >= 0
  40. if t_initial == 1 and t_mul == 1 and decay_rate == 1:
  41. _logger.warning("Cosine annealing scheduler will have no effect on the learning "
  42. "rate since t_initial = t_mul = eta_mul = 1.")
  43. self.t_initial = t_initial
  44. self.t_mul = t_mul
  45. self.lr_min = lr_min
  46. self.decay_rate = decay_rate
  47. self.cycle_limit = cycle_limit
  48. self.warmup_t = warmup_t
  49. self.warmup_lr_init = warmup_lr_init
  50. self.warmup_prefix = warmup_prefix
  51. self.t_in_epochs = t_in_epochs
  52. if self.warmup_t:
  53. self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
  54. super().update_groups(self.warmup_lr_init)
  55. else:
  56. self.warmup_steps = [1 for _ in self.base_values]
  57. def _get_lr(self, t):
  58. if t < self.warmup_t:
  59. lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
  60. else:
  61. if self.warmup_prefix:
  62. t = t - self.warmup_t
  63. if self.t_mul != 1:
  64. i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul))
  65. t_i = self.t_mul ** i * self.t_initial
  66. t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial
  67. else:
  68. i = t // self.t_initial
  69. t_i = self.t_initial
  70. t_curr = t - (self.t_initial * i)
  71. gamma = self.decay_rate ** i
  72. lr_min = self.lr_min * gamma
  73. lr_max_values = [v * gamma for v in self.base_values]
  74. if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit):
  75. lrs = [
  76. lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t_curr / t_i)) for lr_max in lr_max_values
  77. ]
  78. else:
  79. lrs = [self.lr_min for _ in self.base_values]
  80. return lrs
  81. def get_epoch_values(self, epoch: int):
  82. if self.t_in_epochs:
  83. return self._get_lr(epoch)
  84. else:
  85. return None
  86. def get_update_values(self, num_updates: int):
  87. if not self.t_in_epochs:
  88. return self._get_lr(num_updates)
  89. else:
  90. return None
  91. def get_cycle_length(self, cycles=0):
  92. if not cycles:
  93. cycles = self.cycle_limit
  94. cycles = max(1, cycles)
  95. if self.t_mul == 1.0:
  96. return self.t_initial * cycles
  97. else:
  98. return int(math.floor(-self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul)))