main.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440
  1. # --------------------------------------------------------
  2. # Modified by $@#Anonymous#@$
  3. # --------------------------------------------------------
  4. # Swin Transformer
  5. # Copyright (c) 2021 Microsoft
  6. # Licensed under The MIT License [see LICENSE for details]
  7. # Written by Ze Liu
  8. # --------------------------------------------------------
  9. import os
  10. import time
  11. import json
  12. import random
  13. import argparse
  14. import datetime
  15. import tqdm
  16. import numpy as np
  17. import torch
  18. import torch.backends.cudnn as cudnn
  19. import torch.distributed as dist
  20. from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
  21. from timm.utils import accuracy, AverageMeter
  22. from config import get_config
  23. from models import build_model
  24. from data import build_loader
  25. from utils.lr_scheduler import build_scheduler
  26. from utils.optimizer import build_optimizer
  27. from utils.logger import create_logger
  28. from utils.utils import NativeScalerWithGradNormCount, auto_resume_helper, reduce_tensor
  29. from utils.utils import load_checkpoint_ema, load_pretrained_ema, save_checkpoint_ema
  30. from fvcore.nn import FlopCountAnalysis, flop_count_str, flop_count
  31. from timm.utils import ModelEma as ModelEma
  32. if torch.multiprocessing.get_start_method() != "spawn":
  33. print(f"||{torch.multiprocessing.get_start_method()}||", end="")
  34. torch.multiprocessing.set_start_method("spawn", force=True)
  35. def str2bool(v):
  36. """
  37. Converts string to bool type; enables command line
  38. arguments in the format of '--arg1 true --arg2 false'
  39. """
  40. if isinstance(v, bool):
  41. return v
  42. if v.lower() in ('yes', 'true', 't', 'y', '1'):
  43. return True
  44. elif v.lower() in ('no', 'false', 'f', 'n', '0'):
  45. return False
  46. else:
  47. raise argparse.ArgumentTypeError('Boolean value expected.')
  48. def parse_option():
  49. parser = argparse.ArgumentParser('Swin Transformer training and evaluation script', add_help=False)
  50. parser.add_argument('--cfg', type=str, metavar="FILE", default="", help='path to config file', )
  51. parser.add_argument(
  52. "--opts",
  53. help="Modify config options by adding 'KEY VALUE' pairs. ",
  54. default=None,
  55. nargs='+',
  56. )
  57. # easy config modification
  58. parser.add_argument('--batch-size', type=int, help="batch size for single GPU")
  59. parser.add_argument('--data-path', type=str, default="/dataset/ImageNet_ILSVRC2012", help='path to dataset')
  60. parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset')
  61. parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'],
  62. help='no: no cache, '
  63. 'full: cache all data, '
  64. 'part: sharding the dataset into nonoverlapping pieces and only cache one piece')
  65. parser.add_argument('--pretrained',
  66. help='pretrained weight from checkpoint, could be imagenet22k pretrained weight')
  67. parser.add_argument('--resume', help='resume from checkpoint')
  68. parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps")
  69. parser.add_argument('--use-checkpoint', action='store_true',
  70. help="whether to use gradient checkpointing to save memory")
  71. parser.add_argument('--disable_amp', action='store_true', help='Disable pytorch amp')
  72. parser.add_argument('--output', default='output', type=str, metavar='PATH',
  73. help='root of output folder, the full path is <output>/<model_name>/<tag> (default: output)')
  74. parser.add_argument('--tag', default=time.strftime("%Y%m%d%H%M%S", time.localtime()), help='tag of experiment')
  75. parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
  76. parser.add_argument('--throughput', action='store_true', help='Test throughput only')
  77. parser.add_argument('--fused_layernorm', action='store_true', help='Use fused layernorm.')
  78. parser.add_argument('--optim', type=str, help='overwrite optimizer if provided, can be adamw/sgd.')
  79. # EMA related parameters
  80. parser.add_argument('--model_ema', type=str2bool, default=True)
  81. parser.add_argument('--model_ema_decay', type=float, default=0.9999, help='')
  82. parser.add_argument('--model_ema_force_cpu', type=str2bool, default=False, help='')
  83. parser.add_argument('--memory_limit_rate', type=float, default=-1, help='limitation of gpu memory use')
  84. args, unparsed = parser.parse_known_args()
  85. config = get_config(args)
  86. return args, config
  87. def main(config, args):
  88. dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config)
  89. logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
  90. model = build_model(config)
  91. if dist.get_rank() == 0:
  92. if hasattr(model, 'flops'):
  93. logger.info(str(model))
  94. n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
  95. logger.info(f"number of params: {n_parameters}")
  96. flops = model.flops()
  97. logger.info(f"number of GFLOPs: {flops / 1e9}")
  98. else:
  99. logger.info(flop_count_str(FlopCountAnalysis(model, (dataset_val[0][0][None],))))
  100. torch.cuda.empty_cache()
  101. dist.barrier()
  102. model.cuda()
  103. model_without_ddp = model
  104. model_ema = None
  105. if args.model_ema:
  106. # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
  107. model_ema = ModelEma(
  108. model,
  109. decay=args.model_ema_decay,
  110. device='cpu' if args.model_ema_force_cpu else '',
  111. resume='')
  112. print("Using EMA with decay = %.8f" % args.model_ema_decay)
  113. optimizer = build_optimizer(config, model, logger)
  114. model = torch.nn.parallel.DistributedDataParallel(model, broadcast_buffers=False)
  115. loss_scaler = NativeScalerWithGradNormCount()
  116. if config.TRAIN.ACCUMULATION_STEPS > 1:
  117. lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train) // config.TRAIN.ACCUMULATION_STEPS)
  118. else:
  119. lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train))
  120. if config.AUG.MIXUP > 0.:
  121. # smoothing is handled with mixup label transform
  122. criterion = SoftTargetCrossEntropy()
  123. elif config.MODEL.LABEL_SMOOTHING > 0.:
  124. criterion = LabelSmoothingCrossEntropy(smoothing=config.MODEL.LABEL_SMOOTHING)
  125. else:
  126. criterion = torch.nn.CrossEntropyLoss()
  127. max_accuracy = 0.0
  128. max_accuracy_ema = 0.0
  129. if config.TRAIN.AUTO_RESUME:
  130. resume_file = auto_resume_helper(config.OUTPUT)
  131. if resume_file:
  132. if config.MODEL.RESUME:
  133. logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}")
  134. config.defrost()
  135. config.MODEL.RESUME = resume_file
  136. config.freeze()
  137. logger.info(f'auto resuming from {resume_file}')
  138. else:
  139. logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume')
  140. if config.MODEL.RESUME:
  141. max_accuracy, max_accuracy_ema = load_checkpoint_ema(config, model_without_ddp, optimizer, lr_scheduler, loss_scaler, logger, model_ema)
  142. acc1, acc5, loss = validate(config, data_loader_val, model)
  143. logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
  144. if model_ema is not None:
  145. acc1_ema, acc5_ema, loss_ema = validate(config, data_loader_val, model_ema.ema)
  146. logger.info(f"Accuracy of the network ema on the {len(dataset_val)} test images: {acc1_ema:.1f}%")
  147. if config.EVAL_MODE:
  148. return
  149. if config.MODEL.PRETRAINED and (not config.MODEL.RESUME):
  150. load_pretrained_ema(config, model_without_ddp, logger, model_ema)
  151. acc1, acc5, loss = validate(config, data_loader_val, model)
  152. logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
  153. if model_ema is not None:
  154. acc1_ema, acc5_ema, loss_ema = validate(config, data_loader_val, model_ema.ema)
  155. logger.info(f"Accuracy of the network ema on the {len(dataset_val)} test images: {acc1_ema:.1f}%")
  156. if config.EVAL_MODE:
  157. return
  158. if config.THROUGHPUT_MODE and (dist.get_rank() == 0):
  159. logger.info(f"throughput mode ==============================")
  160. throughput(data_loader_val, model, logger)
  161. if model_ema is not None:
  162. torch.cuda.synchronize()
  163. torch.cuda.empty_cache()
  164. throughput(data_loader_val, model_ema.ema, logger)
  165. return
  166. logger.info("Start training")
  167. start_time = time.time()
  168. for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
  169. data_loader_train.sampler.set_epoch(epoch)
  170. train_one_epoch(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler, loss_scaler, model_ema)
  171. if dist.get_rank() == 0 and (epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1)):
  172. save_checkpoint_ema(config, epoch, model_without_ddp, max_accuracy, optimizer, lr_scheduler, loss_scaler, logger, model_ema, max_accuracy_ema)
  173. acc1, acc5, loss = validate(config, data_loader_val, model)
  174. logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
  175. max_accuracy = max(max_accuracy, acc1)
  176. logger.info(f'Max accuracy: {max_accuracy:.2f}%')
  177. if model_ema is not None:
  178. acc1_ema, acc5_ema, loss_ema = validate(config, data_loader_val, model_ema.ema)
  179. logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1_ema:.1f}%")
  180. max_accuracy_ema = max(max_accuracy_ema, acc1_ema)
  181. logger.info(f'Max accuracy ema: {max_accuracy_ema:.2f}%')
  182. total_time = time.time() - start_time
  183. total_time_str = str(datetime.timedelta(seconds=int(total_time)))
  184. logger.info('Training time {}'.format(total_time_str))
  185. def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler, loss_scaler, model_ema=None, model_time_warmup=50):
  186. model.train()
  187. optimizer.zero_grad()
  188. num_steps = len(data_loader)
  189. batch_time = AverageMeter()
  190. model_time = AverageMeter()
  191. data_time = AverageMeter()
  192. loss_meter = AverageMeter()
  193. norm_meter = AverageMeter()
  194. scaler_meter = AverageMeter()
  195. start = time.time()
  196. end = time.time()
  197. for idx, (samples, targets) in enumerate(data_loader):
  198. torch.cuda.reset_peak_memory_stats()
  199. samples = samples.cuda(non_blocking=True)
  200. targets = targets.cuda(non_blocking=True)
  201. if mixup_fn is not None:
  202. samples, targets = mixup_fn(samples, targets)
  203. data_time.update(time.time() - end)
  204. with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE):
  205. outputs = model(samples)
  206. loss = criterion(outputs, targets)
  207. loss = loss / config.TRAIN.ACCUMULATION_STEPS
  208. # this attribute is added by timm on one optimizer (adahessian)
  209. is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
  210. grad_norm = loss_scaler(loss, optimizer, clip_grad=config.TRAIN.CLIP_GRAD,
  211. parameters=model.parameters(), create_graph=is_second_order,
  212. update_grad=(idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0)
  213. if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0:
  214. optimizer.zero_grad()
  215. lr_scheduler.step_update((epoch * num_steps + idx) // config.TRAIN.ACCUMULATION_STEPS)
  216. if model_ema is not None:
  217. model_ema.update(model)
  218. loss_scale_value = loss_scaler.state_dict()["scale"]
  219. torch.cuda.synchronize()
  220. loss_meter.update(loss.item(), targets.size(0))
  221. if grad_norm is not None: # loss_scaler return None if not update
  222. norm_meter.update(grad_norm)
  223. scaler_meter.update(loss_scale_value)
  224. batch_time.update(time.time() - end)
  225. end = time.time()
  226. if idx > model_time_warmup:
  227. model_time.update(batch_time.val - data_time.val)
  228. if idx % config.PRINT_FREQ == 0:
  229. lr = optimizer.param_groups[0]['lr']
  230. wd = optimizer.param_groups[0]['weight_decay']
  231. memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
  232. etas = batch_time.avg * (num_steps - idx)
  233. logger.info(
  234. f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
  235. f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t wd {wd:.4f}\t'
  236. f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
  237. f'data time {data_time.val:.4f} ({data_time.avg:.4f})\t'
  238. f'model time {model_time.val:.4f} ({model_time.avg:.4f})\t'
  239. f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
  240. f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t'
  241. f'loss_scale {scaler_meter.val:.4f} ({scaler_meter.avg:.4f})\t'
  242. f'mem {memory_used:.0f}MB')
  243. epoch_time = time.time() - start
  244. logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}")
  245. @torch.no_grad()
  246. def validate(config, data_loader, model):
  247. criterion = torch.nn.CrossEntropyLoss()
  248. model.eval()
  249. batch_time = AverageMeter()
  250. loss_meter = AverageMeter()
  251. acc1_meter = AverageMeter()
  252. acc5_meter = AverageMeter()
  253. end = time.time()
  254. for idx, (images, target) in enumerate(data_loader):
  255. images = images.cuda(non_blocking=True)
  256. target = target.cuda(non_blocking=True)
  257. # compute output
  258. with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE):
  259. output = model(images)
  260. # measure accuracy and record loss
  261. loss = criterion(output, target)
  262. acc1, acc5 = accuracy(output, target, topk=(1, 5))
  263. acc1 = reduce_tensor(acc1)
  264. acc5 = reduce_tensor(acc5)
  265. loss = reduce_tensor(loss)
  266. loss_meter.update(loss.item(), target.size(0))
  267. acc1_meter.update(acc1.item(), target.size(0))
  268. acc5_meter.update(acc5.item(), target.size(0))
  269. # measure elapsed time
  270. batch_time.update(time.time() - end)
  271. end = time.time()
  272. if idx % config.PRINT_FREQ == 0:
  273. memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
  274. logger.info(
  275. f'Test: [{idx}/{len(data_loader)}]\t'
  276. f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
  277. f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
  278. f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
  279. f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t'
  280. f'Mem {memory_used:.0f}MB')
  281. logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}')
  282. return acc1_meter.avg, acc5_meter.avg, loss_meter.avg
  283. @torch.no_grad()
  284. def throughput(data_loader, model, logger):
  285. model.eval()
  286. for idx, (images, _) in enumerate(data_loader):
  287. images = images.cuda(non_blocking=True)
  288. batch_size = images.shape[0]
  289. for i in range(50):
  290. model(images)
  291. torch.cuda.synchronize()
  292. logger.info(f"throughput averaged with 30 times")
  293. tic1 = time.time()
  294. for i in range(30):
  295. model(images)
  296. torch.cuda.synchronize()
  297. tic2 = time.time()
  298. logger.info(f"batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}")
  299. return
  300. if __name__ == '__main__':
  301. args, config = parse_option()
  302. if config.AMP_OPT_LEVEL:
  303. print("[warning] Apex amp has been deprecated, please use pytorch amp instead!")
  304. if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
  305. rank = int(os.environ["RANK"])
  306. world_size = int(os.environ['WORLD_SIZE'])
  307. print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}")
  308. else:
  309. rank = -1
  310. world_size = -1
  311. torch.cuda.set_device(rank)
  312. dist.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)
  313. dist.barrier()
  314. seed = config.SEED + dist.get_rank()
  315. torch.manual_seed(seed)
  316. torch.cuda.manual_seed(seed)
  317. np.random.seed(seed)
  318. random.seed(seed)
  319. cudnn.benchmark = True
  320. if True:
  321. torch.backends.cudnn.enabled = True
  322. torch.backends.cudnn.benchmark = True
  323. torch.backends.cudnn.deterministic = True
  324. # linear scale the learning rate according to total batch size, may not be optimal
  325. linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
  326. linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
  327. linear_scaled_min_lr = config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
  328. # gradient accumulation also need to scale the learning rate
  329. if config.TRAIN.ACCUMULATION_STEPS > 1:
  330. linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUMULATION_STEPS
  331. linear_scaled_warmup_lr = linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS
  332. linear_scaled_min_lr = linear_scaled_min_lr * config.TRAIN.ACCUMULATION_STEPS
  333. config.defrost()
  334. config.TRAIN.BASE_LR = linear_scaled_lr
  335. config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr
  336. config.TRAIN.MIN_LR = linear_scaled_min_lr
  337. config.freeze()
  338. # to make sure all the config.OUTPUT are the same
  339. config.defrost()
  340. if dist.get_rank() == 0:
  341. obj = [config.OUTPUT]
  342. # obj = [str(random.randint(0, 100))] # for test
  343. else:
  344. obj = [None]
  345. dist.broadcast_object_list(obj)
  346. dist.barrier()
  347. config.OUTPUT = obj[0]
  348. print(config.OUTPUT, flush=True)
  349. config.freeze()
  350. os.makedirs(config.OUTPUT, exist_ok=True)
  351. logger = create_logger(output_dir=config.OUTPUT, dist_rank=dist.get_rank(), name=f"{config.MODEL.NAME}")
  352. if dist.get_rank() == 0:
  353. path = os.path.join(config.OUTPUT, "config.json")
  354. with open(path, "w") as f:
  355. f.write(config.dump())
  356. logger.info(f"Full config saved to {path}")
  357. # print config
  358. logger.info(config.dump())
  359. logger.info(json.dumps(vars(args)))
  360. if args.memory_limit_rate > 0 and args.memory_limit_rate < 1:
  361. torch.cuda.set_per_process_memory_fraction(args.memory_limit_rate)
  362. usable_memory = torch.cuda.get_device_properties(0).total_memory * args.memory_limit_rate / 1e6
  363. print(f"===========> GPU memory is limited to {usable_memory}MB", flush=True)
  364. main(config, args)