lr_scheduler.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. # --------------------------------------------------------
  2. # Swin Transformer
  3. # Copyright (c) 2021 Microsoft
  4. # Licensed under The MIT License [see LICENSE for details]
  5. # Written by Ze Liu
  6. # --------------------------------------------------------
  7. import bisect
  8. import torch
  9. from timm.scheduler.cosine_lr import CosineLRScheduler
  10. from timm.scheduler.step_lr import StepLRScheduler
  11. from timm.scheduler.scheduler import Scheduler
  12. import timm
  13. if timm.__version__ != "0.4.12":
  14. from .cosine_lr import CosineLRScheduler
  15. def build_scheduler(config, optimizer, n_iter_per_epoch):
  16. num_steps = int(config.TRAIN.EPOCHS * n_iter_per_epoch)
  17. warmup_steps = int(config.TRAIN.WARMUP_EPOCHS * n_iter_per_epoch)
  18. decay_steps = int(config.TRAIN.LR_SCHEDULER.DECAY_EPOCHS * n_iter_per_epoch)
  19. multi_steps = [i * n_iter_per_epoch for i in config.TRAIN.LR_SCHEDULER.MULTISTEPS]
  20. lr_scheduler = None
  21. if config.TRAIN.LR_SCHEDULER.NAME == 'cosine':
  22. lr_scheduler = CosineLRScheduler(
  23. optimizer,
  24. t_initial=(num_steps - warmup_steps) if config.TRAIN.LR_SCHEDULER.WARMUP_PREFIX else num_steps,
  25. t_mul=1.,
  26. lr_min=config.TRAIN.MIN_LR,
  27. warmup_lr_init=config.TRAIN.WARMUP_LR,
  28. warmup_t=warmup_steps,
  29. cycle_limit=1,
  30. t_in_epochs=False,
  31. warmup_prefix=config.TRAIN.LR_SCHEDULER.WARMUP_PREFIX,
  32. )
  33. elif config.TRAIN.LR_SCHEDULER.NAME == 'linear':
  34. lr_scheduler = LinearLRScheduler(
  35. optimizer,
  36. t_initial=num_steps,
  37. lr_min_rate=0.01,
  38. warmup_lr_init=config.TRAIN.WARMUP_LR,
  39. warmup_t=warmup_steps,
  40. t_in_epochs=False,
  41. )
  42. elif config.TRAIN.LR_SCHEDULER.NAME == 'step':
  43. lr_scheduler = StepLRScheduler(
  44. optimizer,
  45. decay_t=decay_steps,
  46. decay_rate=config.TRAIN.LR_SCHEDULER.DECAY_RATE,
  47. warmup_lr_init=config.TRAIN.WARMUP_LR,
  48. warmup_t=warmup_steps,
  49. t_in_epochs=False,
  50. )
  51. elif config.TRAIN.LR_SCHEDULER.NAME == 'multistep':
  52. lr_scheduler = MultiStepLRScheduler(
  53. optimizer,
  54. milestones=multi_steps,
  55. gamma=config.TRAIN.LR_SCHEDULER.GAMMA,
  56. warmup_lr_init=config.TRAIN.WARMUP_LR,
  57. warmup_t=warmup_steps,
  58. t_in_epochs=False,
  59. )
  60. return lr_scheduler
  61. class LinearLRScheduler(Scheduler):
  62. def __init__(self,
  63. optimizer: torch.optim.Optimizer,
  64. t_initial: int,
  65. lr_min_rate: float,
  66. warmup_t=0,
  67. warmup_lr_init=0.,
  68. t_in_epochs=True,
  69. noise_range_t=None,
  70. noise_pct=0.67,
  71. noise_std=1.0,
  72. noise_seed=42,
  73. initialize=True,
  74. ) -> None:
  75. super().__init__(
  76. optimizer, param_group_field="lr",
  77. noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
  78. initialize=initialize)
  79. self.t_initial = t_initial
  80. self.lr_min_rate = lr_min_rate
  81. self.warmup_t = warmup_t
  82. self.warmup_lr_init = warmup_lr_init
  83. self.t_in_epochs = t_in_epochs
  84. if self.warmup_t:
  85. self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
  86. super().update_groups(self.warmup_lr_init)
  87. else:
  88. self.warmup_steps = [1 for _ in self.base_values]
  89. def _get_lr(self, t):
  90. if t < self.warmup_t:
  91. lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
  92. else:
  93. t = t - self.warmup_t
  94. total_t = self.t_initial - self.warmup_t
  95. lrs = [v - ((v - v * self.lr_min_rate) * (t / total_t)) for v in self.base_values]
  96. return lrs
  97. def get_epoch_values(self, epoch: int):
  98. if self.t_in_epochs:
  99. return self._get_lr(epoch)
  100. else:
  101. return None
  102. def get_update_values(self, num_updates: int):
  103. if not self.t_in_epochs:
  104. return self._get_lr(num_updates)
  105. else:
  106. return None
  107. class MultiStepLRScheduler(Scheduler):
  108. def __init__(self, optimizer: torch.optim.Optimizer, milestones, gamma=0.1, warmup_t=0, warmup_lr_init=0, t_in_epochs=True) -> None:
  109. super().__init__(optimizer, param_group_field="lr")
  110. self.milestones = milestones
  111. self.gamma = gamma
  112. self.warmup_t = warmup_t
  113. self.warmup_lr_init = warmup_lr_init
  114. self.t_in_epochs = t_in_epochs
  115. if self.warmup_t:
  116. self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
  117. super().update_groups(self.warmup_lr_init)
  118. else:
  119. self.warmup_steps = [1 for _ in self.base_values]
  120. assert self.warmup_t <= min(self.milestones)
  121. def _get_lr(self, t):
  122. if t < self.warmup_t:
  123. lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
  124. else:
  125. lrs = [v * (self.gamma ** bisect.bisect_right(self.milestones, t)) for v in self.base_values]
  126. return lrs
  127. def get_epoch_values(self, epoch: int):
  128. if self.t_in_epochs:
  129. return self._get_lr(epoch)
  130. else:
  131. return None
  132. def get_update_values(self, num_updates: int):
  133. if not self.t_in_epochs:
  134. return self._get_lr(num_updates)
  135. else:
  136. return None