| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394 |
- import os
- import copy
- import shutil
- import datetime
- import torch
- from util.util import makedirs, log_cfg, able, log_msg, get_log_terms, update_log_term, accuracy
- from util.net import trans_state_dict, print_networks, get_timepc, reduce_tensor
- from optim.scheduler import get_scheduler
- from data import get_loader
- from model import get_model
- from optim import get_optim
- from loss import get_loss_terms
- from timm.data import Mixup
- from torch.nn.parallel import DistributedDataParallel as NativeDDP
- try:
- from apex import amp
- from apex.parallel import DistributedDataParallel as ApexDDP
- from apex.parallel import convert_syncbn_model as ApexSyncBN
- except:
- from timm.layers.norm_act import convert_sync_batchnorm as ApexSyncBN
- from timm.layers.norm_act import convert_sync_batchnorm as TIMMSyncBN
- from timm.utils import dispatch_clip_grad
- from util.net import get_loss_scaler, get_autocast, distribute_bn
- from ._base_trainer import BaseTrainer
- from . import TRAINER
- @TRAINER.register_module
- class CLSTrainer(BaseTrainer):
- def __init__(self, cfg):
- super(CLSTrainer, self).__init__(cfg)
-
- def set_input(self, inputs):
- self.imgs = inputs['img'].cuda()
- self.targets = inputs['target'].cuda()
- self.bs = self.imgs.shape[0]
-
- def forward(self, net=None):
- net = net if net is not None else self.net
- self.outputs = net(self.imgs)
- if not isinstance(self.outputs, dict):
- self.outputs = {'out': self.outputs, 'out_kd': self.outputs}
-
- def backward_term(self, loss_term, optim):
- optim.zero_grad()
- if self.loss_scaler:
- self.loss_scaler(loss_term, optim, clip_grad=self.cfg.loss.clip_grad, parameters=self.net.parameters(), create_graph=self.cfg.loss.create_graph)
- else:
- loss_term.backward(retain_graph=self.cfg.loss.retain_graph)
- if self.cfg.loss.clip_grad is not None:
- dispatch_clip_grad(self.net.parameters(), value=self.cfg.loss.clip_grad)
- optim.step()
- def optimize_parameters(self):
- if self.mixup_fn is not None:
- self.imgs, self.targets = self.mixup_fn(self.imgs, self.targets)
- with self.amp_autocast():
- self.forward()
- nan_or_inf_out = 1. if torch.any(torch.isnan(self.outputs['out'])) or torch.any(torch.isinf(self.outputs['out'])) else 0.
- nan_or_inf_out = reduce_tensor(nan_or_inf_out, self.world_size, mode='sum', sum_avg=False).clone().detach().item()
- nan_or_inf_out = True if nan_or_inf_out > 0. else False
- if nan_or_inf_out:
- self.nan_or_inf_cnt += 1
- log_msg(self.logger, f'NaN or Inf Found, total {self.nan_or_inf_cnt} times')
- self.check_bn()
- loss_ce = self.loss_terms['CE'](self.outputs['out'], self.targets) if not nan_or_inf_out else 0
- loss_kd = (self.loss_terms['KD'](self.outputs['out_kd'], self.imgs) if self.loss_terms.get('KD', None) else 0) if not nan_or_inf_out else 0
- self.backward_term((loss_ce + loss_kd) if not nan_or_inf_out else (0 * self.outputs['out'][0, 0]), self.optim)
- update_log_term(self.log_terms.get('CE'), reduce_tensor(loss_ce, self.world_size).clone().detach().item(), 1, self.master)
- update_log_term(self.log_terms.get('KD'), reduce_tensor(loss_kd, self.world_size).clone().detach().item(), 1, self.master)
- self._update_ema()
- @torch.no_grad()
- def test(self):
- tops = self.test_net(self.net, name='net')
- self.is_best = True if len(self.topk_recorder['net_top1']) == 0 or tops[0] > max(self.topk_recorder['net_top1']) else False
- self.topk_recorder['net_top1'].append(tops[0])
- self.topk_recorder['net_top5'].append(tops[1])
- max_top1 = max(self.topk_recorder['net_top1'])
- max_top1_idx = self.topk_recorder['net_top1'].index(max_top1) + 1
- msg = 'Max [top1: {:>3.3f} (epoch: {:d})]'.format(max_top1, max_top1_idx)
- if self.ema:
- tops = self.test_net(self.net_E, name='net_E')
- self.is_best_ema = True if len(self.topk_recorder['net_E_top1']) == 0 or tops[0] > max(
- self.topk_recorder['net_E_top1']) else False
- self.topk_recorder['net_E_top1'].append(tops[0])
- self.topk_recorder['net_E_top5'].append(tops[1])
- max_top1_ema = max(self.topk_recorder['net_E_top1'])
- max_top1_idx_ema = self.topk_recorder['net_E_top1'].index(max_top1_ema) + 1
- msg += ' [top1-ema: {:>3.3f} (epoch: {:d})]'.format(max_top1_ema, max_top1_idx_ema)
- log_msg(self.logger, msg)
|