net.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252
  1. import os
  2. import random
  3. import shutil
  4. import copy
  5. import numpy as np
  6. import torch
  7. import torch.nn as nn
  8. import torch.nn.functional as F
  9. import torch.backends.cudnn as cudnn
  10. import torch.distributed as dist
  11. import math
  12. import time
  13. try:
  14. from collections import Iterable
  15. except ImportError:
  16. from collections.abc import Iterable
  17. from timm.utils.agc import adaptive_clip_grad
  18. from util.util import log_msg
  19. from fvcore.nn import FlopCountAnalysis, flop_count_table
  20. from timm.utils import NativeScaler, ApexScaler
  21. from contextlib import suppress, contextmanager
  22. def init_training(cfg):
  23. # ---------- cudnn ----------
  24. if not torch.cuda.is_available():
  25. print('==> GPU error')
  26. exit(0)
  27. torch.cuda.empty_cache()
  28. if cfg.trainer.cuda_deterministic: # slower, more reproducible
  29. cudnn.deterministic = True
  30. cudnn.benchmark = False
  31. else: # faster, less reproducible
  32. cudnn.deterministic = False
  33. cudnn.benchmark = True
  34. # ---------- dist ----------
  35. cfg.dist = True
  36. cfg.world_size, cfg.rank, cfg.local_rank = 1, 0, 0
  37. cfg.ngpus_per_node = torch.cuda.device_count()
  38. if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
  39. cfg.world_size = int(os.environ['WORLD_SIZE'])
  40. cfg.rank = int(os.environ["RANK"])
  41. cfg.local_rank = int(os.environ['LOCAL_RANK'])
  42. cfg.nnodes = cfg.world_size // cfg.ngpus_per_node
  43. elif 'SLURM_PROCID' in os.environ:
  44. cfg.rank = int(os.environ['SLURM_PROCID'])
  45. cfg.local_rank = cfg.rank % cfg.ngpus_per_node
  46. cfg.nnodes = cfg.world_size // cfg.ngpus_per_node
  47. else:
  48. cfg.dist = False
  49. cfg.nnodes = 1
  50. if cfg.dist:
  51. torch.cuda.set_device(cfg.local_rank)
  52. cfg.master = cfg.rank == cfg.logger_rank
  53. cfg.dist_backend = 'nccl'
  54. torch.distributed.init_process_group(backend=cfg.dist_backend, init_method=cfg.dist_url)
  55. torch.distributed.barrier()
  56. else:
  57. cfg.master = True
  58. # ---------- seed ----------
  59. seed = cfg.seed + cfg.local_rank
  60. np.random.seed(seed)
  61. random.seed(seed)
  62. torch.manual_seed(seed)
  63. torch.cuda.manual_seed(seed)
  64. # ---------- dataset ----------
  65. if cfg.trainer.data.batch_size:
  66. cfg.trainer.data.batch_size_per_gpu = cfg.trainer.data.batch_size // cfg.world_size
  67. assert cfg.trainer.data.batch_size_per_gpu * cfg.world_size == cfg.trainer.data.batch_size
  68. else:
  69. cfg.trainer.data.batch_size = cfg.trainer.data.batch_size_per_gpu * cfg.world_size
  70. if cfg.trainer.data.batch_size_test:
  71. cfg.trainer.data.batch_size_per_gpu_test = cfg.trainer.data.batch_size_test // cfg.world_size
  72. assert cfg.trainer.data.batch_size_per_gpu_test * cfg.world_size == cfg.trainer.data.batch_size_test
  73. else:
  74. cfg.trainer.data.batch_size_test = cfg.trainer.data.batch_size_per_gpu_test * cfg.world_size
  75. cfg.trainer.data.num_workers = cfg.trainer.data.num_workers_per_gpu * cfg.world_size
  76. def init_modules(modules, w_init='xavier_normal'):
  77. if w_init == "normal":
  78. _init = torch.nn.init.normal_
  79. elif w_init == "xavier_normal":
  80. _init = torch.nn.init.xavier_normal_
  81. elif w_init == "xavier_uniform":
  82. _init = torch.nn.init.xavier_uniform_
  83. elif w_init == "kaiming_normal":
  84. _init = torch.nn.init.kaiming_normal_
  85. elif w_init == "kaiming_uniform":
  86. _init = torch.nn.init.kaiming_uniform_
  87. elif w_init == "orthogonal":
  88. _init = torch.nn.init.orthogonal_
  89. else:
  90. raise NotImplementedError
  91. if isinstance(modules, Iterable):
  92. for m in modules:
  93. if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
  94. _init(m.weight)
  95. if m.bias is not None:
  96. torch.nn.init.zeros_(m.bias)
  97. if isinstance(m, (nn.LSTM, nn.GRU)):
  98. for name, param in m.named_parameters():
  99. if 'bias' in name:
  100. nn.init.zeros_(param)
  101. elif 'weight' in name:
  102. _init(param)
  103. def trans_state_dict(state_dict, dist=True):
  104. state_dict_modify = dict()
  105. if dist:
  106. for k, v in state_dict.items():
  107. k = k if k.startswith('module') else 'module.'+k
  108. state_dict_modify[k] = v
  109. else:
  110. for k, v in state_dict.items():
  111. k = k[7:] if k.startswith('module') else k
  112. state_dict_modify[k] = v
  113. return state_dict_modify
  114. def dispatch_clip_grad(parameters, value, mode='norm', norm_type=2.0):
  115. if mode == 'norm':
  116. torch.nn.utils.clip_grad_norm_(parameters, value, norm_type=norm_type)
  117. elif mode == 'value':
  118. torch.nn.utils.clip_grad_value_(parameters, value)
  119. elif mode == 'agc':
  120. adaptive_clip_grad(parameters, value, norm_type=norm_type)
  121. else:
  122. raise ValueError('invalid clip mode: {}'.format(mode))
  123. def get_params(model, names):
  124. params = []
  125. for name in names:
  126. params.extend(list(model.__getattribute__(name).parameters()))
  127. return params
  128. def get_timepc(cuda_synchronize=False):
  129. if torch.cuda.is_available() and cuda_synchronize:
  130. torch.cuda.synchronize()
  131. return time.perf_counter()
  132. def set_requires_grad(model, requires_grad=False):
  133. for p in model.parameters():
  134. p.requires_grad = requires_grad
  135. def print_networks(models, size, logger):
  136. for model in models:
  137. result = '\n' + '-' * 36 + ' {} '.format(type(model).__name__) + '-' * 36 + '\n'
  138. # total_num_params = 0
  139. # for i, (name, child) in enumerate(model.named_children()):
  140. # num_params = sum([p.numel() for p in child.parameters()]) / 1e6
  141. # total_num_params += num_params
  142. # result += '{}: {:<.3f}M\n'.format(name, num_params)
  143. # for i, (grandname, grandchild) in enumerate(child.named_children()):
  144. # num_params = sum([p.numel() for p in grandchild.parameters()]) / 1e6
  145. # result += '==> {}: {:<3.3f}M\n'.format(grandname, num_params)
  146. # total_num_params_with_parameter_vars = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6
  147. # result += '[Network {}] Total number of parameters: {:<.3f}M (with parameter_vars: {:<.3f}M)\n'.format(type(model).__name__, total_num_params, total_num_params_with_parameter_vars)
  148. flops = FlopCountAnalysis(model, torch.randn([1, 3, size, size], dtype=list(model.parameters())[0].dtype, device=list(model.parameters())[0].device))
  149. result += '{}\n'.format(flop_count_table(flops, max_depth=5))
  150. result += '-' * (72 + 2 + len(type(model).__name__))
  151. log_msg(logger, result)
  152. def reduce_tensor(tensor, world_size, mode='sum', sum_avg=True, rank=0):
  153. if isinstance(tensor, torch.Tensor):
  154. tensor_ = tensor.detach()
  155. if tensor_.device == torch.device('cpu'):
  156. tensor_ = tensor_.cuda()
  157. else:
  158. tensor_ = torch.tensor(tensor).float().cuda()
  159. if world_size == 1:
  160. return tensor_
  161. if mode == 'sum':
  162. dist.barrier()
  163. dist.all_reduce(tensor_, op=torch.distributed.ReduceOp.SUM, )
  164. if sum_avg:
  165. tensor_ /= world_size
  166. tensor_out = tensor_
  167. elif mode == 'cat':
  168. size = [1] * len(tensor_.shape)
  169. size[0] = world_size
  170. tensor_out = torch.zeros_like(tensor_, dtype=tensor_.dtype, device=tensor_.device)
  171. tensor_out = tensor_out.repeat(size)
  172. B = tensor_.shape[0]
  173. tensor_out[rank * B:(rank+1) * B] = tensor_
  174. dist.barrier()
  175. dist.all_reduce(tensor_out, op=torch.distributed.ReduceOp.SUM, )
  176. elif mode == 'and':
  177. dist.barrier()
  178. dist.all_reduce(tensor_, op=torch.distributed.ReduceOp.BAND, )
  179. tensor_out = tensor_
  180. elif mode == 'or':
  181. dist.barrier()
  182. dist.all_reduce(tensor_, op=torch.distributed.ReduceOp.BOR, )
  183. tensor_out = tensor_
  184. else:
  185. raise 'invalid reduce mode: {}'.format(mode)
  186. return tensor_out
  187. def distribute_bn(model, world_size, dist_bn):
  188. # ensure every node has the same running bn stats
  189. model = model.module if hasattr(model, 'module') else model
  190. for bn_name, bn_buf in model.named_buffers(recurse=True):
  191. if ('running_mean' in bn_name) or ('running_var' in bn_name):
  192. if dist_bn == 'reduce':
  193. # average bn stats across whole group
  194. torch.distributed.all_reduce(bn_buf, op=dist.ReduceOp.SUM)
  195. bn_buf /= float(world_size)
  196. elif dist_bn == 'broadcast':
  197. # broadcast bn stats from rank 0 to whole group
  198. torch.distributed.broadcast(bn_buf, 0)
  199. else:
  200. pass
  201. def get_loss_scaler(scaler='native'):
  202. scaler_dict = {
  203. 'none': None,
  204. 'native': NativeScaler(),
  205. 'apex': ApexScaler(),
  206. }
  207. return scaler_dict[scaler]
  208. @contextmanager
  209. def placeholder():
  210. yield
  211. def get_autocast(autocast='native'):
  212. autocast_dict = {
  213. 'none': placeholder,
  214. 'native': torch.cuda.amp.autocast,
  215. 'apex': placeholder,
  216. }
  217. return autocast_dict[autocast]
  218. def get_net_params(net, requires_grad=True):
  219. num_params = 0
  220. for param in net.parameters():
  221. if requires_grad and param.requires_grad:
  222. num_params += param.numel()
  223. return num_params