cls.py 4.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. import os
  2. import copy
  3. import shutil
  4. import datetime
  5. import torch
  6. from util.util import makedirs, log_cfg, able, log_msg, get_log_terms, update_log_term, accuracy
  7. from util.net import trans_state_dict, print_networks, get_timepc, reduce_tensor
  8. from optim.scheduler import get_scheduler
  9. from data import get_loader
  10. from model import get_model
  11. from optim import get_optim
  12. from loss import get_loss_terms
  13. from timm.data import Mixup
  14. from torch.nn.parallel import DistributedDataParallel as NativeDDP
  15. try:
  16. from apex import amp
  17. from apex.parallel import DistributedDataParallel as ApexDDP
  18. from apex.parallel import convert_syncbn_model as ApexSyncBN
  19. except:
  20. from timm.layers.norm_act import convert_sync_batchnorm as ApexSyncBN
  21. from timm.layers.norm_act import convert_sync_batchnorm as TIMMSyncBN
  22. from timm.utils import dispatch_clip_grad
  23. from util.net import get_loss_scaler, get_autocast, distribute_bn
  24. from ._base_trainer import BaseTrainer
  25. from . import TRAINER
  @TRAINER.register_module
  class CLSTrainer(BaseTrainer):
      """Image-classification trainer on top of ``BaseTrainer``.

      Implements the per-batch hooks (``set_input``, ``forward``,
      ``backward_term``, ``optimize_parameters``) and the evaluation hook
      (``test``).  Attributes such as ``net``, ``optim``, ``loss_scaler``,
      ``amp_autocast``, ``mixup_fn``, ``loss_terms``, ``log_terms``,
      ``topk_recorder``, ``test_net`` and ``_update_ema`` are not defined in
      this class; presumably they are set up by ``BaseTrainer``.
      """

      def __init__(self, cfg):
          super().__init__(cfg)

      def set_input(self, inputs):
          """Move one batch to the GPU and cache it on the trainer.

          Args:
              inputs: mapping with at least the keys ``'img'`` and ``'target'``
                  (tensors).  ``self.bs`` is set from ``img.shape[0]``.
          """
          self.imgs = inputs['img'].cuda()
          self.targets = inputs['target'].cuda()
          self.bs = self.imgs.shape[0]

      def forward(self, net=None):
          """Run ``net`` (default: ``self.net``) on ``self.imgs``.

          The result is stored in ``self.outputs`` as a dict.  A non-dict
          output is wrapped so that ``'out'`` (fed to the CE loss) and
          ``'out_kd'`` (fed to the optional KD loss) are the same tensor.
          """
          net = self.net if net is None else net
          self.outputs = net(self.imgs)
          if not isinstance(self.outputs, dict):
              self.outputs = {'out': self.outputs, 'out_kd': self.outputs}

      def backward_term(self, loss_term, optim):
          """Back-propagate ``loss_term`` and step ``optim``.

          With a loss scaler, the scaler is called with the loss, optimizer,
          clip value and parameters.  Presumably it performs backward, clipping
          and the step itself (timm ``NativeScaler``-style); confirm in
          ``util.net.get_loss_scaler``.  Without a scaler, this does a plain
          ``backward``, then optional clipping via timm's ``dispatch_clip_grad``,
          then ``optim.step()``.
          """
          optim.zero_grad()
          if self.loss_scaler:
              self.loss_scaler(
                  loss_term, optim,
                  clip_grad=self.cfg.loss.clip_grad,
                  parameters=self.net.parameters(),
                  create_graph=self.cfg.loss.create_graph,
              )
              return
          loss_term.backward(retain_graph=self.cfg.loss.retain_graph)
          if self.cfg.loss.clip_grad is not None:
              dispatch_clip_grad(self.net.parameters(), value=self.cfg.loss.clip_grad)
          optim.step()

      def optimize_parameters(self):
          """Run one training iteration on the batch stored by ``set_input``.

          The steps are:

          1. Optional mixup.
          2. Forward pass.
          3. A NaN/Inf check on ``outputs['out']``, summed over all ranks.
          4. Loss computation (CE plus optional KD).
          5. Backward and optimizer step.
          6. Loss logging and EMA update.

          If any rank saw a non-finite output, every rank skips the real losses
          and back-propagates ``0 * out[0, 0]`` instead.  Every rank still runs
          a backward pass, presumably to keep distributed ranks in lock-step.
          The event is counted, logged and followed by ``check_bn()``.
          """
          if self.mixup_fn is not None:
              self.imgs, self.targets = self.mixup_fn(self.imgs, self.targets)
          with self.amp_autocast():
              self.forward()
          out = self.outputs['out']

          # isfinite is False for NaN and +/-Inf: one reduction, one host sync.
          local_flag = 0. if torch.isfinite(out).all() else 1.
          global_flag = reduce_tensor(
              local_flag, self.world_size, mode='sum', sum_avg=False,
          ).clone().detach().item()
          nan_or_inf_out = global_flag > 0.

          if nan_or_inf_out:
              self.nan_or_inf_cnt += 1
              log_msg(self.logger, f'NaN or Inf Found, total {self.nan_or_inf_cnt} times')
              self.check_bn()
              loss_ce, loss_kd = 0, 0
              total_loss = 0 * out[0, 0]
          else:
              # PyTorch AMP: losses belong inside the autocast region so ops
              # such as cross_entropy / log_softmax are autocast to float32.
              with self.amp_autocast():
                  loss_ce = self.loss_terms['CE'](out, self.targets)
                  kd_term = self.loss_terms.get('KD', None)
                  loss_kd = kd_term(self.outputs['out_kd'], self.imgs) if kd_term else 0
              total_loss = loss_ce + loss_kd

          self.backward_term(total_loss, self.optim)
          for term_name, term_value in (('CE', loss_ce), ('KD', loss_kd)):
              reduced = reduce_tensor(term_value, self.world_size).clone().detach().item()
              update_log_term(self.log_terms.get(term_name), reduced, 1, self.master)
          self._update_ema()

      def _record_tops(self, name, tops):
          """Append ``tops[0]`` / ``tops[1]`` to ``topk_recorder['<name>_top1'/'_top5']``.

          Returns:
              ``(is_best, best_top1, best_top1_idx)``, where:

              - ``is_best`` is True when the new top-1 strictly exceeds every
                earlier top-1, or when it is the first entry.
              - ``best_top1_idx`` is the 1-based position of the best top-1 in
                the history.  It is reported as the epoch, which assumes
                ``test`` runs once per epoch.
          """
          top1_hist = self.topk_recorder[f'{name}_top1']
          # bool(): accuracies may be 0-dim tensors; keep a plain Python bool.
          is_best = bool(len(top1_hist) == 0 or tops[0] > max(top1_hist))
          top1_hist.append(tops[0])
          self.topk_recorder[f'{name}_top5'].append(tops[1])
          best_top1 = max(top1_hist)
          return is_best, best_top1, top1_hist.index(best_top1) + 1

      @torch.no_grad()
      def test(self):
          """Evaluate the model(s) and log the best top-1 seen so far.

          ``self.net`` is always evaluated; ``self.net_E`` is evaluated too when
          ``self.ema`` is truthy.  This updates ``self.is_best`` (and
          ``self.is_best_ema`` when ``self.ema`` is truthy).
          """
          tops = self.test_net(self.net, name='net')
          self.is_best, max_top1, max_top1_idx = self._record_tops('net', tops)
          msg = 'Max [top1: {:>3.3f} (epoch: {:d})]'.format(max_top1, max_top1_idx)
          if self.ema:
              tops_ema = self.test_net(self.net_E, name='net_E')
              self.is_best_ema, max_top1_ema, max_top1_idx_ema = self._record_tops('net_E', tops_ema)
              msg += ' [top1-ema: {:>3.3f} (epoch: {:d})]'.format(max_top1_ema, max_top1_idx_ema)
          log_msg(self.logger, msg)