scheduler.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. import torch
  2. from timm.scheduler.cosine_lr import CosineLRScheduler
  3. from timm.scheduler.tanh_lr import TanhLRScheduler
  4. from timm.scheduler.step_lr import StepLRScheduler
  5. from timm.scheduler.plateau_lr import PlateauLRScheduler
  6. def get_scheduler(cfg, optimizer):
  7. kwargs = {k: v for k, v in cfg.trainer.scheduler_kwargs.items()}
  8. name = kwargs.pop('name')
  9. use_iters = kwargs.pop('use_iters')
  10. if getattr(cfg.trainer, 'epoch_full', None):
  11. cfg.trainer.iter_full = cfg.data.train_size * cfg.trainer.epoch_full
  12. else:
  13. cfg.trainer.epoch_full = cfg.trainer.iter_full / cfg.data.train_size
  14. if kwargs['warmup_iters'] > -1:
  15. t_initial = cfg.trainer.iter_full
  16. warmup_t = kwargs['warmup_iters']
  17. decay_t = kwargs['decay_iters']
  18. patience_t = kwargs['patience_iters']
  19. if not use_iters:
  20. t_initial, warmup_t, decay_t, patience_t = [t / cfg.data.train_size for t in [t_initial, warmup_t, decay_t, patience_t]]
  21. cfg.trainer.iter_full += (kwargs['cooldown_iters'] + patience_t)
  22. cfg.trainer.epoch_full = cfg.trainer.iter_full / cfg.data.train_size
  23. elif kwargs['warmup_epochs'] > -1:
  24. t_initial = cfg.trainer.epoch_full
  25. warmup_t = kwargs['warmup_epochs']
  26. decay_t = kwargs['decay_epochs']
  27. patience_t = kwargs['patience_epochs']
  28. if use_iters:
  29. t_initial, warmup_t, decay_t, patience_t = [t * cfg.data.train_size for t in [t_initial, warmup_t, decay_t, patience_t]]
  30. cfg.trainer.epoch_full += (kwargs['cooldown_epochs'] + patience_t)
  31. cfg.trainer.iter_full = cfg.trainer.epoch_full * cfg.data.train_size
  32. else:
  33. raise Exception("invalid \'warmup_iters\' and \'warmup_epochs\'")
  34. if kwargs.get('lr_noise', None) is not None:
  35. lr_noise = kwargs.get('lr_noise')
  36. if isinstance(lr_noise, (list, tuple)):
  37. noise_range_t = [n * t_initial for n in lr_noise]
  38. if len(noise_range_t) == 1:
  39. noise_range_t = noise_range_t[0]
  40. else:
  41. noise_range_t = lr_noise * t_initial
  42. else:
  43. noise_range_t = None
  44. kwargs_common = dict(optimizer=optimizer,
  45. warmup_lr_init=kwargs['warmup_lr'],
  46. warmup_t=warmup_t,
  47. noise_pct=kwargs.get('noise_pct', 0.67),
  48. noise_std=kwargs.get('noise_std', 1.),
  49. noise_seed=kwargs.get('noise_seed', 42),
  50. noise_range_t=noise_range_t,
  51. )
  52. if name == 'cosine':
  53. lr_scheduler = CosineLRScheduler(
  54. **kwargs_common,
  55. t_initial=t_initial,
  56. cycle_mul=kwargs.get('lr_cycle_mul', 1.),
  57. lr_min=kwargs['lr_min'],
  58. cycle_decay=kwargs['cycle_decay'],
  59. cycle_limit=kwargs.get('lr_cycle_limit', 1),
  60. t_in_epochs=True,
  61. )
  62. elif name == 'tanh':
  63. lr_scheduler = TanhLRScheduler(
  64. **kwargs_common,
  65. t_initial=t_initial,
  66. cycle_mul=kwargs.get('lr_cycle_mul', 1.),
  67. lr_min=kwargs['lr_min'],
  68. cycle_limit=kwargs.get('lr_cycle_limit', 1),
  69. t_in_epochs=True,
  70. )
  71. elif name == 'step':
  72. lr_scheduler = StepLRScheduler(
  73. **kwargs_common,
  74. decay_t=decay_t,
  75. decay_rate=kwargs['decay_rate'],
  76. )
  77. elif name == 'plateau':
  78. mode = 'min' if 'loss' in kwargs.get('eval_metric', '') else 'max'
  79. lr_scheduler = PlateauLRScheduler(
  80. **kwargs_common,
  81. decay_rate=kwargs['decay_rate'],
  82. patience_t=kwargs['patience_iters'],
  83. lr_min=kwargs['lr_min'],
  84. mode=mode,
  85. cooldown_t=0,
  86. )
  87. else:
  88. raise Exception(f'invalid scheduler: {name}')
  89. return lr_scheduler