# clseval.py
  1. import argparse
  2. import os
  3. import time
  4. import random
  5. from collections import OrderedDict
  6. import numpy as np
  7. import torch
  8. import torch.nn as nn
  9. import torch.distributed as dist
  10. import torch.optim
  11. import torch.utils.data
  12. from timm.utils import accuracy, AverageMeter
  13. import logging
  14. logger = logging
  15. HOME = os.environ["HOME"].rstrip("/")
  16. def parse_options():
  17. parser = argparse.ArgumentParser(description="PyTorch ImageNet Training")
  18. parser.add_argument("--epochs", default=30, type=int)
  19. parser.add_argument("-b", "--batch-size", default=4096, type=int, dest="batch_size")
  20. parser.add_argument("--lr", "--learning-rate", default=30.0, type=float, dest="lr")
  21. parser.add_argument("--momentum", default=0.9, type=float, help="momentum")
  22. parser.add_argument("--wd", "--weight-decay", default=0.0, type=float, dest="weight_decay")
  23. parser.add_argument("--reinit", action="store_true")
  24. parser.add_argument("-e", "--evaluate", action="store_true", dest="evaluate")
  25. parser.add_argument("--seed", default=0, type=int)
  26. parser.add_argument("--size", default=224, type=int, help="img size")
  27. parser.add_argument("--name", default="all", type=str, help="model name")
  28. args = parser.parse_args()
  29. print(args)
  30. return args
  31. def get_feats_train_dataloader(features, length=1281167,batch_size=128, distributed=False):
  32. feats = torch.load(open(features, "rb"))
  33. feats, tgts = feats["features"], feats["targets"].long()
  34. assert feats.shape[0] == length
  35. assert tgts.shape[0] == length
  36. class fds(torch.utils.data.Dataset):
  37. def __len__(self):
  38. return length
  39. def __getitem__(self, index):
  40. return feats[index], tgts[index]
  41. dataset_train = fds()
  42. sampler_train = torch.utils.data.DistributedSampler(
  43. dataset_train, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True
  44. ) if distributed else None
  45. data_loader_train = torch.utils.data.DataLoader(
  46. dataset_train,
  47. sampler=sampler_train,
  48. shuffle=(not distributed),
  49. batch_size=batch_size,
  50. num_workers=0,
  51. pin_memory=True,
  52. drop_last=True,
  53. )
  54. return data_loader_train
  55. def get_feats_eval_dataloader(features, length=50000, batch_size=128):
  56. feats = torch.load(open(features, "rb"))
  57. feats, tgts = feats["features"], feats["targets"].long()
  58. assert feats.shape[0] == length
  59. assert tgts.shape[0] == length
  60. class fds(torch.utils.data.Dataset):
  61. def __len__(self):
  62. return length
  63. def __getitem__(self, index):
  64. return feats[index], tgts[index]
  65. dataset_val = fds()
  66. data_loader_val = torch.utils.data.DataLoader(
  67. dataset_val,
  68. sampler=None,
  69. shuffle=False,
  70. batch_size=batch_size,
  71. num_workers=0,
  72. pin_memory=True,
  73. drop_last=False,
  74. )
  75. return data_loader_val
  76. # WARNING!!! acc score would be inaccurate if num_procs > 1, as sampler always pads the dataset
  77. # copied from https://github.com/microsoft/Swin-Transformer/blob/main/main.py
  78. @torch.no_grad()
  79. def validate(data_loader, model, AMP_ENABLE=True, verbose=True):
  80. criterion = torch.nn.CrossEntropyLoss()
  81. model.eval()
  82. batch_time = AverageMeter()
  83. loss_meter = AverageMeter()
  84. acc1_meter = AverageMeter()
  85. acc5_meter = AverageMeter()
  86. end = time.time()
  87. for idx, (images, target) in enumerate(data_loader):
  88. images = images.cuda(non_blocking=True)
  89. target = target.cuda(non_blocking=True)
  90. # compute output
  91. with torch.cuda.amp.autocast(enabled=AMP_ENABLE):
  92. output = model(images)
  93. # measure accuracy and record loss
  94. loss = criterion(output, target)
  95. acc1, acc5 = accuracy(output, target, topk=(1, 5))
  96. loss_meter.update(loss.item(), target.size(0))
  97. acc1_meter.update(acc1.item(), target.size(0))
  98. acc5_meter.update(acc5.item(), target.size(0))
  99. # measure elapsed time
  100. batch_time.update(time.time() - end)
  101. end = time.time()
  102. if verbose:
  103. print(f'* Loss {loss_meter.avg:.4f} Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}', flush=True)
  104. return acc1_meter.avg, acc5_meter.avg, loss_meter.avg
def train(model, args, features_train, features_val, seed=0, state_dict=None, reinit=False, outdir="/tmp", val=False, lr=0.05, verbose=True):
    """Train (or only evaluate) a linear classifier on pre-extracted features.

    Args:
        model: A ``torch.nn.Linear`` (asserted). It is wrapped in a
            ``Sequential`` under the name ``fc`` and moved to CUDA.
        args: Namespace providing ``batch_size``, ``momentum``,
            ``weight_decay`` and ``epochs``.
        features_train: Path handed to ``get_feats_train_dataloader``.
        features_val: Path handed to ``get_feats_eval_dataloader``.
        seed: Integer seed applied to ``random``, NumPy and torch, or ``None``
            to skip seeding.
        state_dict: Optional weights loaded into the linear layer before
            training; a validation pass follows the load.
        reinit: Re-draw the weights from N(0, 0.01) and zero the bias.
        outdir: Directory receiving the final checkpoint.
        val: When true, return after the initial validation without training.
        lr: SGD learning rate. ``args.lr`` is not read by this function.
        verbose: Print per-epoch training statistics and validation results.

    Returns:
        None. The final checkpoint is written to
        ``<outdir>/ckpt_epoch_<args.epochs>.pth``.

    NOTE(review): this listing lost its indentation; the nesting of the two
    initial ``validate`` calls and of the final ``print`` / ``torch.save`` was
    reconstructed and should be confirmed against the upstream file.
    """
    batch_size = args.batch_size
    print(args, dict(model=model, lr=lr, verbose=verbose, seed=seed, reinit=reinit), flush=True)
    assert isinstance(model, torch.nn.Linear)
    # model = torch.nn.Linear(args.dim, args.num_classes, bias=True)
    # The sample counts passed here are fixed: 1281167 training and 50000 validation rows.
    train_loader = get_feats_train_dataloader(features_train, batch_size=batch_size, length=1281167)
    val_loader = get_feats_eval_dataloader(features_val, batch_size=batch_size, length=50000)
    # Wrapping under the key "fc" means the saved state_dict keys are prefixed with "fc.".
    model = torch.nn.Sequential(OrderedDict(fc = model,)).cuda()
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay
    )
    # optimizer = torch.optim.AdamW(model.parameters(), lr)
    # The scheduler is stepped once per batch, hence T_max counts iterations, not epochs.
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs * len(train_loader), eta_min=0)

    if state_dict is not None:
        model.fc.load_state_dict(state_dict)
        # Report the accuracy of the loaded head before any training.
        validate(val_loader, model)

    if seed is not None:
        assert isinstance(seed, int)
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.benchmark = True

    if reinit:
        model.fc.weight.data.normal_(mean=0.0, std=0.01)
        model.fc.bias.data.zero_()
        # Report the accuracy of the freshly re-initialized head.
        validate(val_loader, model, verbose=True)

    if val:
        return

    # [best acc@1, acc@5 at that epoch, loss at that epoch, epoch index]
    maxacc1 = [0, 0, 0, 0]
    for epoch in range(0, args.epochs):
        loss_meter = AverageMeter()
        acc1_meter = AverageMeter()
        acc5_meter = AverageMeter()
        model.train()
        for idx, (images, target) in enumerate(train_loader):
            images = images.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)
            output = model(images)
            loss = criterion(output, target)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            loss_meter.update(loss.item(), images.size(0))
            acc1_meter.update(acc1.item(), images.size(0))
            acc5_meter.update(acc5.item(), images.size(0))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
        if verbose:
            print(
                f'Train[{epoch}/{args.epochs} : {len(train_loader)}]: '
                f'Loss {loss_meter.avg:.4f} '
                f'Acc@1 {acc1_meter.avg:.3f} '
                f'Acc@5 {acc5_meter.avg:.3f} ', flush=True)
        acc1, acc5, loss = validate(val_loader, model, verbose=verbose)
        if acc1 > maxacc1[0]:
            maxacc1 = [acc1, acc5, loss, epoch]

    print(f"max acc: {maxacc1[0:2]}, loss: {maxacc1[2]}, epoch {maxacc1[3]}", flush=True)
    # The file name depends only on args.epochs, so repeated runs into the same
    # outdir overwrite each other's checkpoint.
    torch.save({
        "epoch": args.epochs,
        "state_dict": model.state_dict(),
    }, os.path.join(outdir, f"ckpt_epoch_{args.epochs}.pth"))
  171. if __name__ == "__main__":
  172. args = parse_options()
  173. vmambav2tiny = dict(
  174. name = "vmambav2tiny",
  175. model = nn.Linear(768, 1000, bias=True),
  176. ckpt = f"{HOME}/Workspace/PylanceAware/ckpts/publish/vssm1/classification/vssm1_tiny_0230s/vssm1_tiny_0230s_ckpt_epoch_264.pth",
  177. state_dict = lambda sd: {
  178. "weight": sd["model"]["classifier.head.weight"],
  179. "bias": sd["model"]["classifier.head.bias"],
  180. }
  181. )
  182. vmambav2l5tiny = dict(
  183. name = "vmambav2l5tiny",
  184. model = nn.Linear(768, 1000, bias=True),
  185. ckpt = f"{HOME}/Workspace/PylanceAware/ckpts/publish/vssm1/classification/vssm1_tiny_0230/vssm1_tiny_0230_ckpt_epoch_262.pth",
  186. state_dict = lambda sd: {
  187. "weight": sd["model"]["classifier.head.weight"],
  188. "bias": sd["model"]["classifier.head.bias"],
  189. }
  190. )
  191. vmambav0tiny = dict(
  192. name = "vmambav0tiny",
  193. model = nn.Linear(768, 1000, bias=True),
  194. ckpt = f"{HOME}/Workspace/PylanceAware/ckpts/publish/vssm/classification/vssmtiny/vssmtiny_dp01_ckpt_epoch_292.pth",
  195. state_dict = lambda sd: {
  196. "weight": sd["model"]["head.weight"],
  197. "bias": sd["model"]["head.bias"],
  198. }
  199. )
  200. resnet50 = dict(
  201. name = "resnet50",
  202. model = nn.Linear(2048, 1000, bias=True),
  203. ckpt = f"{HOME}/.cache/torch/hub/checkpoints/resnet50_8xb32_in1k_20210831-ea4938fc.pth",
  204. state_dict = lambda sd: {
  205. "weight": sd["state_dict"]["head.fc.weight"],
  206. "bias": sd["state_dict"]["head.fc.bias"],
  207. }
  208. )
  209. deitsmall = dict(
  210. name = "deitsmall",
  211. model = nn.Linear(384, 1000, bias=True),
  212. ckpt = f"{HOME}/.cache/torch/hub/checkpoints/deit-small_pt-4xb256_in1k_20220218-9425b9bb.pth",
  213. state_dict = lambda sd: {
  214. "weight": sd["state_dict"]["head.layers.head.weight"],
  215. "bias": sd["state_dict"]["head.layers.head.bias"],
  216. }
  217. )
  218. convnexttiny = dict(
  219. name = "convnexttiny",
  220. model = nn.Linear(768, 1000, bias=True),
  221. ckpt = f"{HOME}/packs/ckpts/convnext_tiny_1k_224_ema.pth",
  222. state_dict = lambda sd: {
  223. "weight": sd["model"]["head.weight"],
  224. "bias": sd["model"]["head.bias"],
  225. }
  226. )
  227. swintiny = dict(
  228. name = "swintiny",
  229. model = nn.Linear(768, 1000, bias=True),
  230. ckpt = f"{HOME}/.cache/torch/hub/checkpoints/swin_tiny_224_b16x64_300e_imagenet_20210616_090925-66df6be6.pth",
  231. state_dict = lambda sd: {
  232. "weight": sd["state_dict"]["head.fc.weight"],
  233. "bias": sd["state_dict"]["head.fc.bias"],
  234. }
  235. )
  236. hivittiny = dict(
  237. name = "hivittiny",
  238. model = nn.Linear(384, 1000, bias=True),
  239. ckpt = f"{HOME}/packs/ckpts/hivit-tiny-p16_8xb128_in1k/epoch_295.pth",
  240. state_dict = lambda sd: {
  241. "weight": sd["state_dict"]["head.fc.weight"],
  242. "bias": sd["state_dict"]["head.fc.bias"],
  243. }
  244. )
  245. interntiny = dict(
  246. name = "interntiny",
  247. model = nn.Linear(768, 1000, bias=True),
  248. ckpt = f"{HOME}/packs/ckpts/internimage_t_1k_224.pth",
  249. state_dict = lambda sd: {
  250. "weight": sd["model"]["head.weight"],
  251. "bias": sd["model"]["head.bias"],
  252. }
  253. )
  254. xcittiny = dict(
  255. name = "xcittiny",
  256. model = nn.Linear(384, 1000, bias=True),
  257. ckpt = f"{HOME}/packs/ckpts/xcit_small_12_p16_224.pth",
  258. state_dict = lambda sd: {
  259. "weight": sd["model"]["head.weight"],
  260. "bias": sd["model"]["head.bias"],
  261. }
  262. )
  263. deitbase = dict(
  264. name = "deitbase ",
  265. model = nn.Linear(768, 1000, bias=True),
  266. ckpt = f"{HOME}/.cache/torch/hub/checkpoints/deit-base_pt-16xb64_in1k_20220216-db63c16c.pth",
  267. state_dict = lambda sd: {
  268. "weight": sd["state_dict"]["head.layers.head.weight"],
  269. "bias": sd["state_dict"]["head.layers.head.bias"],
  270. }
  271. )
  272. vims = dict(
  273. name = "vims",
  274. model = nn.Linear(384, 1000, bias=True),
  275. ckpt = f"{HOME}/packs/ckpts/vim_s_midclstok_80p5acc.pth",
  276. state_dict = lambda sd: {
  277. "weight": sd["model"]["head.weight"],
  278. "bias": sd["model"]["head.bias"],
  279. }
  280. )
  281. names = {}
  282. for col in [vmambav2tiny, vmambav2l5tiny, vmambav0tiny, swintiny, convnexttiny, hivittiny, deitsmall, resnet50, interntiny, xcittiny, deitbase, vims]:
  283. names.update({col["name"]: col})
  284. size = 224
  285. model = col["model"]
  286. feature_train = f"{HOME}/ckpts/feats/merge{size}/{col['name']}_sz{size}_train.pth"
  287. feature_val = f"{HOME}/ckpts/feats/merge{size}/{col['name']}_sz{size}_val.pth"
  288. state_dict = col["state_dict"](torch.load(col["ckpt"], map_location=torch.device("cpu")))
  289. if args.name == "all":
  290. # for col in [vmambav2tiny, vmambav2l5tiny, vmambav0tiny, swintiny, convnexttiny, hivittiny, deitsmall, resnet50, interntiny, xcittiny, vims]:
  291. for col in [vims]:
  292. for size, lr in zip([224, 288, 384, 512, 640, 768, 1024], [0.05, 0.05, 0.05, 0.2, 0.5, 0.5, 0.5]):
  293. model = col["model"]
  294. feature_train = f"{HOME}/ckpts/feats/merge{size}/{col['name']}_sz{size}_train.pth"
  295. feature_val = f"{HOME}/ckpts/feats/merge{size}/{col['name']}_sz{size}_val.pth"
  296. state_dict = col["state_dict"](torch.load(col["ckpt"], map_location=torch.device("cpu")))
  297. train(
  298. model=model, args=args, features_train=feature_train, features_val=feature_val,
  299. state_dict=state_dict,
  300. reinit=args.reinit,
  301. val=args.evaluate,
  302. lr=lr,
  303. verbose=False,
  304. )
  305. else:
  306. size = args.size
  307. col = names[args.name]
  308. model = col["model"]
  309. feature_train = f"{HOME}/ckpts/feats/merge{size}/{col['name']}_sz{size}_train.pth"
  310. feature_val = f"{HOME}/ckpts/feats/merge{size}/{col['name']}_sz{size}_val.pth"
  311. state_dict = col["state_dict"](torch.load(col["ckpt"], map_location=torch.device("cpu")))
  312. train(
  313. model=model, args=args, features_train=feature_train, features_val=feature_val,
  314. state_dict=state_dict,
  315. reinit=args.reinit,
  316. val=args.evaluate,
  317. lr = args.lr,
  318. )