utils.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. # --------------------------------------------------------
  2. # Modified By $@#Anonymous#@$
  3. # --------------------------------------------------------
  4. # Swin Transformer
  5. # Copyright (c) 2021 Microsoft
  6. # Licensed under The MIT License [see LICENSE for details]
  7. # Written by Ze Liu
  8. # --------------------------------------------------------
  9. import os
  10. from math import inf
  11. import torch
  12. import torch.distributed as dist
  13. from timm.utils import ModelEma as ModelEma
  14. def load_checkpoint_ema(config, model, optimizer, lr_scheduler, loss_scaler, logger, model_ema: ModelEma=None):
  15. logger.info(f"==============> Resuming form {config.MODEL.RESUME}....................")
  16. if config.MODEL.RESUME.startswith('https'):
  17. checkpoint = torch.hub.load_state_dict_from_url(
  18. config.MODEL.RESUME, map_location='cpu', check_hash=True)
  19. else:
  20. checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu')
  21. if 'model' in checkpoint:
  22. msg = model.load_state_dict(checkpoint['model'], strict=False)
  23. logger.info(f"resuming model: {msg}")
  24. else:
  25. logger.warning(f"No 'model' found in {config.MODEL.RESUME}! ")
  26. if model_ema is not None:
  27. if 'model_ema' in checkpoint:
  28. msg = model_ema.ema.load_state_dict(checkpoint['model_ema'], strict=False)
  29. logger.info(f"resuming model_ema: {msg}")
  30. else:
  31. logger.warning(f"No 'model_ema' found in {config.MODEL.RESUME}! ")
  32. max_accuracy = 0.0
  33. max_accuracy_ema = 0.0
  34. if not config.EVAL_MODE and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
  35. optimizer.load_state_dict(checkpoint['optimizer'])
  36. lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
  37. config.defrost()
  38. config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1
  39. config.freeze()
  40. if 'scaler' in checkpoint:
  41. loss_scaler.load_state_dict(checkpoint['scaler'])
  42. logger.info(f"=> loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})")
  43. if 'max_accuracy' in checkpoint:
  44. max_accuracy = checkpoint['max_accuracy']
  45. if 'max_accuracy_ema' in checkpoint:
  46. max_accuracy_ema = checkpoint['max_accuracy_ema']
  47. del checkpoint
  48. torch.cuda.empty_cache()
  49. return max_accuracy, max_accuracy_ema
  50. def load_pretrained_ema(config, model, logger, model_ema: ModelEma=None):
  51. logger.info(f"==============> Loading weight {config.MODEL.PRETRAINED} for fine-tuning......")
  52. checkpoint = torch.load(config.MODEL.PRETRAINED, map_location='cpu')
  53. if 'model' in checkpoint:
  54. msg = model.load_state_dict(checkpoint['model'], strict=False)
  55. logger.warning(msg)
  56. logger.info(f"=> loaded 'model' successfully from '{config.MODEL.PRETRAINED}'")
  57. else:
  58. logger.warning(f"No 'model' found in {config.MODEL.PRETRAINED}! ")
  59. if model_ema is not None:
  60. if "model_ema" in checkpoint:
  61. logger.info(f"=> loading 'model_ema' separately...")
  62. key = "model_ema" if ("model_ema" in checkpoint) else "model"
  63. if key in checkpoint:
  64. msg = model_ema.ema.load_state_dict(checkpoint[key], strict=False)
  65. logger.warning(msg)
  66. logger.info(f"=> loaded '{key}' successfully from '{config.MODEL.PRETRAINED}' for model_ema")
  67. else:
  68. logger.warning(f"No '{key}' found in {config.MODEL.PRETRAINED}! ")
  69. del checkpoint
  70. torch.cuda.empty_cache()
  71. def save_checkpoint_ema(config, epoch, model, max_accuracy, optimizer, lr_scheduler, loss_scaler, logger, model_ema: ModelEma=None, max_accuracy_ema=None):
  72. save_state = {'model': model.state_dict(),
  73. 'optimizer': optimizer.state_dict(),
  74. 'lr_scheduler': lr_scheduler.state_dict(),
  75. 'max_accuracy': max_accuracy,
  76. 'scaler': loss_scaler.state_dict(),
  77. 'epoch': epoch,
  78. 'config': config}
  79. if model_ema is not None:
  80. save_state.update({'model_ema': model_ema.ema.state_dict(),
  81. 'max_accuray_ema': max_accuracy_ema})
  82. save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth')
  83. logger.info(f"{save_path} saving......")
  84. torch.save(save_state, save_path)
  85. logger.info(f"{save_path} saved !!!")
  86. def get_grad_norm(parameters, norm_type=2):
  87. if isinstance(parameters, torch.Tensor):
  88. parameters = [parameters]
  89. parameters = list(filter(lambda p: p.grad is not None, parameters))
  90. norm_type = float(norm_type)
  91. total_norm = 0
  92. for p in parameters:
  93. param_norm = p.grad.data.norm(norm_type)
  94. total_norm += param_norm.item() ** norm_type
  95. total_norm = total_norm ** (1. / norm_type)
  96. return total_norm
  97. def auto_resume_helper(output_dir):
  98. checkpoints = os.listdir(output_dir)
  99. checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth')]
  100. print(f"All checkpoints founded in {output_dir}: {checkpoints}")
  101. if len(checkpoints) > 0:
  102. latest_checkpoint = max([os.path.join(output_dir, d) for d in checkpoints], key=os.path.getmtime)
  103. print(f"The latest checkpoint founded: {latest_checkpoint}")
  104. resume_file = latest_checkpoint
  105. else:
  106. resume_file = None
  107. return resume_file
  108. def reduce_tensor(tensor):
  109. rt = tensor.clone()
  110. dist.all_reduce(rt, op=dist.ReduceOp.SUM)
  111. rt /= dist.get_world_size()
  112. return rt
  113. def ampscaler_get_grad_norm(parameters, norm_type: float = 2.0) -> torch.Tensor:
  114. if isinstance(parameters, torch.Tensor):
  115. parameters = [parameters]
  116. parameters = [p for p in parameters if p.grad is not None]
  117. norm_type = float(norm_type)
  118. if len(parameters) == 0:
  119. return torch.tensor(0.)
  120. device = parameters[0].grad.device
  121. if norm_type == inf:
  122. total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
  123. else:
  124. total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(),
  125. norm_type).to(device) for p in parameters]), norm_type)
  126. return total_norm
  127. class NativeScalerWithGradNormCount:
  128. state_dict_key = "amp_scaler"
  129. def __init__(self):
  130. self._scaler = torch.cuda.amp.GradScaler()
  131. def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
  132. self._scaler.scale(loss).backward(create_graph=create_graph)
  133. if update_grad:
  134. if clip_grad is not None:
  135. assert parameters is not None
  136. self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
  137. norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
  138. else:
  139. self._scaler.unscale_(optimizer)
  140. norm = ampscaler_get_grad_norm(parameters)
  141. self._scaler.step(optimizer)
  142. self._scaler.update()
  143. else:
  144. norm = None
  145. return norm
  146. def state_dict(self):
  147. return self._scaler.state_dict()
  148. def load_state_dict(self, state_dict):
  149. self._scaler.load_state_dict(state_dict)