eval.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274
  1. import time
  2. import tqdm
  3. import torch
  4. import torch.utils.data
  5. import argparse
  6. import os
  7. import sys
  8. import logging
  9. from functools import partial
  10. from torchvision import datasets, transforms
  11. from torchvision.models.vision_transformer import EncoderBlock
  12. from fvcore.nn import FlopCountAnalysis, flop_count_str, flop_count, parameter_count
  13. import torch
  14. import torch.nn as nn
  15. import torch.backends.cudnn as cudnn
  16. import torch.distributed as dist
  17. from torch.utils.data import DataLoader, SequentialSampler, DistributedSampler
  18. import math
  19. logging.basicConfig(level=logging.INFO)
  20. logger = logging
  21. from timm.utils import accuracy, AverageMeter
  22. from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  23. HOME = os.environ["HOME"].rstrip("/")
  24. basicpath = os.path.abspath("../VMamba/analyze").rstrip("/")
  25. basicpath = os.path.abspath(os.path.dirname(__file__)).rstrip("/")
  26. # this mode will greatly inference the speed!
  27. torch.backends.cudnn.enabled = True
  28. torch.backends.cudnn.benchmark = True
  29. torch.backends.cudnn.deterministic = True
  30. from utils import ExtractFeatures, BuildModels
  31. from analyze_for_vim import ExtraDev
  32. def import_abspy(name="models", path="classification/"):
  33. import sys
  34. import importlib
  35. path = os.path.abspath(path)
  36. assert os.path.isdir(path)
  37. sys.path.insert(0, path)
  38. module = importlib.import_module(name)
  39. sys.path.pop(0)
  40. return module
  41. # copied from https://github.com/microsoft/Swin-Transformer/blob/main/main.py
  42. def reduce_tensor(tensor):
  43. rt = tensor.clone()
  44. dist.all_reduce(rt, op=dist.ReduceOp.SUM)
  45. rt /= dist.get_world_size()
  46. return rt
  47. # WARNING!!! acc score would be inaccurate if num_procs > 1, as sampler always pads the dataset
  48. # copied from https://github.com/microsoft/Swin-Transformer/blob/main/main.py
  49. @torch.no_grad()
  50. def validate(config, data_loader, model):
  51. criterion = torch.nn.CrossEntropyLoss()
  52. model.eval()
  53. batch_time = AverageMeter()
  54. loss_meter = AverageMeter()
  55. acc1_meter = AverageMeter()
  56. acc5_meter = AverageMeter()
  57. end = time.time()
  58. for idx, (images, target) in enumerate(data_loader):
  59. images = images.cuda(non_blocking=True)
  60. target = target.cuda(non_blocking=True)
  61. # compute output
  62. with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE):
  63. output = model(images)
  64. # measure accuracy and record loss
  65. loss = criterion(output, target)
  66. acc1, acc5 = accuracy(output, target, topk=(1, 5))
  67. acc1 = reduce_tensor(acc1)
  68. acc5 = reduce_tensor(acc5)
  69. loss = reduce_tensor(loss)
  70. loss_meter.update(loss.item(), target.size(0))
  71. acc1_meter.update(acc1.item(), target.size(0))
  72. acc5_meter.update(acc5.item(), target.size(0))
  73. # measure elapsed time
  74. batch_time.update(time.time() - end)
  75. end = time.time()
  76. if idx % config.PRINT_FREQ == 0:
  77. memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
  78. logger.info(
  79. f'Test: [{idx}/{len(data_loader)}]\t'
  80. f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
  81. f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
  82. f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
  83. f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t'
  84. f'Mem {memory_used:.0f}MB')
  85. logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}')
  86. return acc1_meter.avg, acc5_meter.avg, loss_meter.avg
  87. def get_dataloader(batch_size=64, root="./val", img_size=224, sequential=True):
  88. size = int((256 / 224) * img_size)
  89. transform = transforms.Compose([
  90. transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
  91. transforms.CenterCrop((img_size, img_size)),
  92. transforms.ToTensor(),
  93. transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
  94. ])
  95. dataset = datasets.ImageFolder(root, transform=transform)
  96. if sequential:
  97. sampler = torch.utils.data.SequentialSampler(dataset)
  98. else:
  99. sampler = torch.utils.data.DistributedSampler(dataset)
  100. data_loader = torch.utils.data.DataLoader(
  101. dataset, sampler=sampler,
  102. batch_size=batch_size,
  103. shuffle=False,
  104. num_workers=0,
  105. pin_memory=True,
  106. drop_last=False
  107. )
  108. return data_loader
  109. def _validate(
  110. model: nn.Module = None,
  111. freq=10,
  112. amp=True,
  113. img_size=224,
  114. batch_size=128,
  115. data_path="/dataset/ImageNet2012",
  116. ):
  117. class Args():
  118. AMP_ENABLE = amp
  119. PRINT_FREQ = freq
  120. config = Args()
  121. model.cuda().eval()
  122. model = torch.nn.parallel.DistributedDataParallel(model)
  123. _batch_size = batch_size
  124. while _batch_size > 0:
  125. try:
  126. _dataloader = get_dataloader(
  127. batch_size=_batch_size,
  128. root=os.path.join(os.path.abspath(data_path), "val"),
  129. img_size=img_size,
  130. sequential=False,
  131. )
  132. logging.info(f"starting loop: img_size {img_size}; len(dataset) {len(_dataloader.dataset)}")
  133. validate(config, data_loader=_dataloader, model=model)
  134. break
  135. except:
  136. _batch_size = _batch_size // 2
  137. print(f"batch_size {_batch_size}", flush=True)
  138. def _extract_feature(data_path="ImageNet_ILSVRC2012", start=0, end=200, step=-1, img_size=224, batch_size=16, train=True, aug=False):
  139. if False:
  140. resnet50 = BuildModels.build_resnet_mmpretrain(with_ckpt=True, remove_head=True, scale="r50", size=img_size).cuda().eval()
  141. deitsmall = BuildModels.build_deit_mmpretrain(with_ckpt=True, remove_head=True, scale="small", size=img_size).cuda().eval()
  142. vmambav0tiny = BuildModels.build_vmamba(with_ckpt=True, remove_head=True, scale="tv0").cuda().eval()
  143. vmambav2l5tiny = BuildModels.build_vmamba(with_ckpt=True, remove_head=True, scale="tv1").cuda().eval()
  144. vmambav2tiny = BuildModels.build_vmamba(with_ckpt=True, remove_head=True, scale="tv2").cuda().eval()
  145. convnexttiny = BuildModels.build_convnext(with_ckpt=True, remove_head=True, scale="tiny").cuda().eval()
  146. swintiny = BuildModels.build_swin_mmpretrain(with_ckpt=True, remove_head=True, scale="tiny", size=img_size).cuda().eval()
  147. hivittiny = BuildModels.build_hivit_mmpretrain(with_ckpt=True, remove_head=True, scale="tiny", size=img_size).cuda().eval()
  148. interntiny = BuildModels.build_intern(with_ckpt=True, remove_head=True, scale="tiny").cuda().eval()
  149. xcittiny = BuildModels.build_xcit(with_ckpt=True, remove_head=True, scale="tiny", size=img_size).cuda().eval()
  150. deitbase = BuildModels.build_deit_mmpretrain(with_ckpt=True, remove_head=True, scale="base", size=img_size).cuda().eval()
  151. if True:
  152. vims = ExtraDev.build_vim_for_throughput(with_ckpt=True, remove_head=True, size=img_size).cuda().eval()
  153. if True:
  154. if step > 0:
  155. starts = list(range(start, end, step))
  156. ends = [s + step for s in starts]
  157. assert ends[-1] >= end
  158. ends[-1] = end
  159. print(f"multiple ranges: {starts} {ends} ==============", flush=True)
  160. else:
  161. starts, ends = [start], [end]
  162. for s, e in zip(starts, ends):
  163. ExtractFeatures.extract_feature(
  164. backbones=dict(
  165. # vmambav2tiny = vmambav2tiny,
  166. # convnexttiny = convnexttiny,
  167. # swintiny = swintiny,
  168. # interntiny = interntiny,
  169. # vmambav0tiny = vmambav0tiny,
  170. # vmambav2l5tiny = vmambav2l5tiny,
  171. # deitsmall = deitsmall,
  172. # hivittiny = hivittiny,
  173. # resnet50 = resnet50,
  174. # xcittiny = xcittiny,
  175. # deitbase = deitbase,
  176. vims = vims,
  177. ),
  178. dims=dict(
  179. # vmambav2tiny = 768,
  180. # convnexttiny = 768,
  181. # swintiny = 768,
  182. # interntiny = 768,
  183. # vmambav0tiny = 768,
  184. # vmambav2l5tiny = 768,
  185. # deitsmall = 384,
  186. # hivittiny = 384,
  187. # resnet50 = 2048,
  188. # xcittiny = 384,
  189. # deitbase = 768,
  190. vims = 384,
  191. ),
  192. batch_size=batch_size,
  193. img_size=img_size,
  194. data_path=data_path,
  195. ranges=(s, e),
  196. train=train,
  197. aug=aug,
  198. )
  199. def main():
  200. parser = argparse.ArgumentParser()
  201. parser.add_argument('--batch-size', type=int, default=32, help="batch size for single GPU")
  202. parser.add_argument('--data-path', type=str, default="ImageNet_ILSVRC2012", help='path to dataset')
  203. parser.add_argument('--mode', type=str, default="", help='model name')
  204. parser.add_argument('--func', type=str, default="", help='function')
  205. parser.add_argument('--start', type=int, default=0, help='start range')
  206. parser.add_argument('--end', type=int, default=200, help='end range')
  207. parser.add_argument('--step', type=int, default=-1, help='step range')
  208. parser.add_argument('--size', type=int, default=224, help='image size')
  209. parser.add_argument('--batch_size', type=int, default=16, help='batch_size')
  210. parser.add_argument('--val', action="store_true", help='...')
  211. parser.add_argument('--aug', action="store_true", help='...')
  212. args = parser.parse_args()
  213. print(args, flush=True)
  214. _extract_feature(args.data_path, args.start, args.end, args.step, args.size, args.batch_size, (not args.val), args.aug)
  215. def run_code_dist_one(func):
  216. if torch.cuda.device_count() > 1:
  217. print("WARNING!!! acc score would be inaccurate if num_procs > 1, as sampler always pads the dataset")
  218. # os.environ["CUDA_VISIBLE_DEVICES"] = "0"
  219. # print(torch.cuda.device_count())
  220. exit()
  221. dist.init_process_group(backend='nccl', init_method='env://', world_size=-1, rank=-1)
  222. else:
  223. os.environ['MASTER_ADDR'] = "127.0.0.1"
  224. os.environ['MASTER_PORT'] = "61234"
  225. while True:
  226. try:
  227. dist.init_process_group(backend='nccl', init_method='env://', world_size=1, rank=0)
  228. break
  229. except Exception as e:
  230. print(e, flush=True)
  231. os.environ['MASTER_PORT'] = f"{int(os.environ['MASTER_PORT']) - 1}"
  232. torch.cuda.set_device(dist.get_rank())
  233. dist.barrier()
  234. func()
  235. if __name__ == "__main__":
  236. run_code_dist_one(main)