loss.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432
  1. import os
  2. import shutil
  3. import torch
  4. import numpy as np
  5. def get_acc_convnext(f: list):
  6. if isinstance(f, str):
  7. f = open(f, "r").readlines()
  8. emaaccs = []
  9. accs = []
  10. for i, line in enumerate(f):
  11. if "* Acc" in line and ("Accuracy of the model EMA on" in f[i + 1]):
  12. l: str = line.strip(" ").split(" ") # [*, Acc@1, 0.642, Acc@5, 2.780, ...]
  13. emaaccs.append(dict(acc1=float(l[2]), acc5=float(l[4])))
  14. elif "* Acc" in line and ("Accuracy of the model on" in f[i + 1]):
  15. l: str = line.strip(" ").split(" ") # [*, Acc@1, 0.642, Acc@5, 2.780, ...]
  16. accs.append(dict(acc1=float(l[2]), acc5=float(l[4])))
  17. accs = dict(acc1=[a['acc1'] for a in accs], acc5=[a['acc5'] for a in accs])
  18. emaaccs = dict(acc1=[a['acc1'] for a in emaaccs], acc5=[a['acc5'] for a in emaaccs])
  19. x_axis = range(len(accs['acc1']))
  20. return x_axis, accs, emaaccs
  21. def get_loss_convnext(f: list, x1e=torch.tensor(list(range(0, 625, 10)) + [624]).view(1, -1) / 625, scale=1):
  22. if isinstance(f, str):
  23. f = open(f, "r").readlines()
  24. avglosses = []
  25. losses = []
  26. for i, line in enumerate(f):
  27. if "Epoch: [" in line and ("loss:" in line):
  28. l = line.split("loss:")[1].strip(" ").split(" ")[:2]
  29. losses.append(float(l[0]))
  30. avglosses.append(float(l[1].split(")")[0].strip("()")))
  31. x = x1e
  32. x = x.repeat(len(losses) // x.shape[1] + 1, 1)
  33. x = x + torch.arange(0, x.shape[0]).view(-1, 1)
  34. x = x.flatten().tolist()
  35. x_axis = x[:len(losses)]
  36. losses = [l * scale for l in losses]
  37. avglosses = [l * scale for l in avglosses]
  38. return x_axis, losses, avglosses
  39. def get_acc_swin(f: list, split_ema=False):
  40. if isinstance(f, str):
  41. f = open(f, "r").readlines()
  42. emaaccs = None
  43. accs = []
  44. for i, line in enumerate(f):
  45. if "* Acc" in line:
  46. l: str = line.split("INFO")[-1].strip(" ").split(" ") # [*, Acc@1, 0.642, Acc@5, 2.780, ...]
  47. accs.append(dict(acc1=float(l[2]), acc5=float(l[4])))
  48. accs = dict(acc1=[a['acc1'] for a in accs], acc5=[a['acc5'] for a in accs])
  49. if split_ema:
  50. emaaccs = dict(acc1=[a for i, a in enumerate(accs['acc1']) if i % 2 == 1],
  51. acc5=[a for i, a in enumerate(accs['acc5']) if i % 2 == 1])
  52. accs = dict(acc1=[a for i, a in enumerate(accs['acc1']) if i % 2 == 0],
  53. acc5=[a for i, a in enumerate(accs['acc5']) if i % 2 == 0])
  54. x_axis = range(len(accs['acc1']))
  55. return x_axis, accs, emaaccs
  56. def get_loss_swin(f: list, x1e=torch.tensor(list(range(0, 1253, 10))).view(1, -1) / 1253, scale=1):
  57. if isinstance(f, str):
  58. f = open(f, "r").readlines()
  59. avglosses = []
  60. losses = []
  61. for i, line in enumerate(f):
  62. if "Train: [" in line and ("loss" in line):
  63. l = line.split("loss")[1].strip(" ").split(" ")[:2]
  64. losses.append(float(l[0]))
  65. avglosses.append(float(l[1].split(")")[0].strip("()")))
  66. x = x1e
  67. x = x.repeat(len(losses) // x.shape[1] + 1, 1)
  68. x = x + torch.arange(0, x.shape[0]).view(-1, 1)
  69. x = x.flatten().tolist()
  70. x_axis = x[:len(losses)]
  71. losses = [l * scale for l in losses]
  72. avglosses = [l * scale for l in avglosses]
  73. return x_axis, losses, avglosses
  74. def get_acc_mmpretrain(f: list):
  75. if isinstance(f, str):
  76. f = open(f, "r").readlines()
  77. accs = []
  78. for i, line in enumerate(f):
  79. if "accuracy_top-1" in line:
  80. line = line.split("accuracy_top-1")[1] # ": 81.182, "accuracy_top-5": 95.606}
  81. lis = line.split("accuracy_top-5") # [": 81.182, ", ": 95.606}]
  82. acc1 = float(lis[0].split(",")[0].split(" ")[-1])
  83. acc5 = float(lis[1].split("}")[0].split(" ")[-1])
  84. accs.append(dict(acc1=acc1, acc5=acc5))
  85. accs = dict(acc1=[a['acc1'] for a in accs], acc5=[a['acc5'] for a in accs])
  86. x_axis = list(range(10, 10 * len(accs['acc1']) + 1, 10))
  87. return x_axis, accs, None
  88. def get_loss_mmpretrain(f: list, x1e=torch.tensor(list(range(100, 1201, 100))).view(1, -1) / 1201, scale=1):
  89. if isinstance(f, str):
  90. f = open(f, "r").readlines()
  91. losses = []
  92. for i, line in enumerate(f):
  93. if "loss" in line:
  94. line = line.split("loss")[1].split(",")[0].split(" ")[-1] # 6.95273
  95. losses.append(float(line))
  96. x = x1e
  97. x = x.repeat(len(losses) // x.shape[1] + 1, 1)
  98. x = x + torch.arange(0, x.shape[0]).view(-1, 1)
  99. x = x.flatten().tolist()
  100. x_axis = x[:len(losses)]
  101. losses = [l * scale for l in losses]
  102. # avglosses = [l * scale for l in avglosses]
  103. return x_axis, None, losses
  104. def linefit(xaxis, yaxis, fit_range=None, out_range=None):
  105. import numpy as np
  106. if fit_range is not None:
  107. # asset xaxis increases
  108. start, end = 0, -1
  109. for i in range(len(xaxis)):
  110. if xaxis[i] <= fit_range[0] and ((i == len(xaxis) - 1) or xaxis[i + 1] > fit_range[0]):
  111. start = i
  112. if xaxis[i] < fit_range[1] and ((i == len(xaxis) - 1) or xaxis[i + 1] >= fit_range[1]):
  113. end = i
  114. if start == end:
  115. raise IndexError(f"{fit_range} out of range.")
  116. xaxis = xaxis[start: end]
  117. yaxis = yaxis[start: end]
  118. if out_range is None:
  119. out_range = fit_range
  120. outx = out_range
  121. z = np.polyfit(xaxis, yaxis, deg=1)
  122. return outx, [z[0] * _x + z[1] for _x in outx]
  123. def draw_fig(data: list, xlim=(0, 301), ylim=(68, 84), xstep=None,ystep=None, save_path="./show.jpg"):
  124. assert isinstance(data[0], dict)
  125. from matplotlib import pyplot as plot
  126. fig, ax = plot.subplots(dpi=400, figsize=(24, 8))
  127. for d in data:
  128. length = min(len(d['x']), len(d['y']))
  129. x_axis = d['x'][:length]
  130. y_axis = d['y'][:length]
  131. label = d['label']
  132. ax.plot(x_axis, y_axis, label=label)
  133. plot.xlim(xlim)
  134. plot.ylim(ylim)
  135. plot.legend()
  136. if xstep is not None:
  137. plot.xticks(torch.arange(xlim[0], xlim[1], xstep).tolist())
  138. if ystep is not None:
  139. plot.yticks(torch.arange(ylim[0], ylim[1], ystep).tolist())
  140. plot.grid()
  141. # plot.show()
  142. os.makedirs(os.path.dirname(save_path), exist_ok=True)
  143. plot.savefig(save_path)
  144. def readlog_classification(logfile):
  145. _logs = open(logfile).readlines()
  146. MAX_SEARCH = 300
  147. _epochs, _accs, _emaaccs = [], [], []
  148. for i in range(0, len(_logs)):
  149. _lr = _logs[i]
  150. if "INFO" in _lr and f"ckpt_epoch_" in _lr and f".pth saved !!!\n" in _lr:
  151. epoch = int(_lr.split("ckpt_epoch_")[1].split(".pth")[0])
  152. _acc, _emaacc = -1, -1
  153. for j in range(i + 1, min(i + MAX_SEARCH, len(_logs))):
  154. if f"INFO Max accuracy:" in _logs[j]:
  155. assert "INFO Accuracy of the network" in _logs[j-1]
  156. assert "INFO * Acc@1" in _logs[j-2]
  157. _acc = float(_logs[j-2].split("INFO * Acc@1")[1].strip().split(" ")[0].strip())
  158. if f"INFO Max accuracy ema:" in _logs[j]:
  159. assert "INFO Accuracy of the network" in _logs[j-1]
  160. assert "INFO * Acc@1" in _logs[j-2]
  161. _emaacc = float(_logs[j-2].split("INFO * Acc@1")[1].strip().split(" ")[0].strip())
  162. if f"INFO Train:" in _logs[j]:
  163. break
  164. _epochs.append(epoch)
  165. _accs.append(_acc)
  166. _emaaccs.append(_emaacc)
  167. _max_acc = np.array(_accs).max() if len(_accs) > 0 else -1
  168. _max_acc_idx = np.flatnonzero(np.array(_accs) == _max_acc)
  169. _max_emaacc = np.array(_emaaccs).max() if len(_emaaccs) > 0 else -1
  170. _max_emaacc_idx = np.flatnonzero(np.array(_emaaccs) == _max_emaacc)
  171. _mkidx = np.union1d(_max_acc_idx, _max_emaacc_idx)
  172. _final_epoch = max(_epochs)
  173. print(f"\033[4;32mmax acc ema: {_max_emaacc}, {[_epochs[i] for i in _max_emaacc_idx]}; max acc: {_max_acc}, {[_epochs[i] for i in _max_acc_idx]}; final ckpt: {_final_epoch}; \033[0m")
  174. _ckpts = [f"ckpt_epoch_{e}.pth" for e in set([_final_epoch, *[_epochs[i] for i in _mkidx]]) if e != -1]
  175. return _ckpts
  176. def readlog_mmdetection(logfile):
  177. _logs = open(logfile).readlines()
  178. _coco_bbox_mAPs, _coco_segm_mAPs, _epochs, _keylogs = [], [], [], []
  179. for i, _l in enumerate(_logs):
  180. if ("Epoch(val)" in _l) and (" eta: 0:" not in _l):
  181. epoch = int(_l.split("Epoch(val) [")[1].split("][")[0].strip())
  182. _epochs.append(epoch)
  183. if "coco/bbox_mAP" in _l:
  184. _coco_bbox_mAPs.append(float(_l.split(" coco/bbox_mAP: ")[1].strip().split(" ")[0].strip()))
  185. if "coco/segm_mAP" in _l:
  186. _coco_segm_mAPs.append(float(_l.split(" coco/segm_mAP: ")[1].strip().split(" ")[0].strip()))
  187. _keylogs.append(_l.strip("\n"))
  188. _max_coco_bbox_mAP = np.array(_coco_bbox_mAPs).max() if len(_coco_bbox_mAPs) > 0 else -1
  189. _max_bbox_idx = np.flatnonzero(_coco_bbox_mAPs == _max_coco_bbox_mAP)
  190. _max_coco_segm_mAP = np.array(_coco_segm_mAPs).max() if len(_coco_segm_mAPs) > 0 else -1
  191. _max_segm_idx = np.flatnonzero(_coco_segm_mAPs == _max_coco_segm_mAP)
  192. _mkidx = np.union1d(_max_bbox_idx, _max_segm_idx)
  193. _mkepochs = [_epochs[i] for i in _mkidx]
  194. _mklogs = [_keylogs[i] for i in _mkidx]
  195. _final_epoch = max(_epochs)
  196. print(f"\033[4;32mbboxmAP: {_max_coco_bbox_mAP}; segmmAP: {_max_coco_segm_mAP}; _mkepochs: {_mkepochs}; final epoch: {_final_epoch}; \033[0m")
  197. for l in _mklogs:
  198. perfs = " ".join([
  199. l.split("Epoch(val) ")[1].split(" ")[0],
  200. l.split("coco/bbox_mAP: ")[1].split(" ")[0],
  201. l.split("coco/bbox_mAP_50: ")[1].split(" ")[0],
  202. l.split("coco/bbox_mAP_75: ")[1].split(" ")[0],
  203. l.split("coco/segm_mAP: ")[1].split(" ")[0],
  204. l.split("coco/segm_mAP_50: ")[1].split(" ")[0],
  205. l.split("coco/segm_mAP_75: ")[1].split(" ")[0],
  206. ])
  207. print(f"\033[4;32m{perfs} \033[0m")
  208. print(_mklogs, logfile)
  209. _ckpts = [f"epoch_{e}.pth" for e in set([_final_epoch, *_mkepochs]) if e != -1]
  210. return _ckpts
  211. def readlog_mmsegmentation(logfile, test=False):
  212. _logs = open(logfile).readlines()
  213. _mious, _iters, _keylogs = [], [], []
  214. for i, _l in enumerate(_logs):
  215. if test:
  216. if ("INFO - Iter(test) [" in _l) and (" eta: " not in _l):
  217. _iters.append(-1)
  218. _mious.append(float(_l.split(" mIoU:")[1].strip().split(" ")[0].strip()))
  219. _keylogs.append(_l.strip("\n"))
  220. else:
  221. if ("INFO - Iter(val) [" in _l) and (" eta: " not in _l):
  222. _iter = None
  223. for j in range(i, -1000, -1):
  224. if "Saving checkpoint at" in _logs[j]:
  225. _iter = int(_logs[j].split("Saving checkpoint at ")[1].strip().split(" ")[0].strip())
  226. break
  227. assert isinstance(_iter, int), "ERROR: can not find iter"
  228. _iters.append(_iter)
  229. _mious.append(float(_l.split(" mIoU:")[1].strip().split(" ")[0].strip()))
  230. _keylogs.append(_l.strip("\n"))
  231. _max_miou = np.array(_mious).max() if len(_mious) > 0 else -1
  232. _max_miou_idx = np.flatnonzero(_mious == _max_miou)
  233. _mkidx = _max_miou_idx
  234. _mkiters = [_iters[i] for i in _mkidx]
  235. _mklogs = [_keylogs[i] for i in _mkidx]
  236. _final_iter = max(_iters)
  237. print(f"\033[4;32mmiou: {_max_miou}; _mkiters: {_mkiters}; final iter: {_final_iter}; \033[0m")
  238. print(_mklogs, logfile)
  239. _ckpts = [f"iter_{e}.pth" for e in set([_final_iter, *_mkiters]) if e != -1]
  240. return _ckpts
  241. def cpclassification(src, name, dstpath="", fake_copy=False, update=False, onlylog=False):
  242. dst = os.path.join(dstpath, name)
  243. os.makedirs(dst, exist_ok=True)
  244. print(f"\033[4;32m{name} =======================================\033[0m")
  245. for file in ["config.json", "log_rank0.txt"]:
  246. if os.path.exists(os.path.join(dst, file)):
  247. print(f"WARNING: file [{os.path.join(dst, file)}] exist already")
  248. if not update:
  249. continue
  250. _s = os.path.join(src, file)
  251. assert os.path.exists(dst) and os.path.isdir(dst)
  252. assert os.path.exists(_s), f"Not found: {_s}"
  253. print(f"copy from [{_s}] to [{dst}]")
  254. if not fake_copy:
  255. shutil.copy(_s, dst)
  256. _ckpts = readlog_classification(os.path.join(os.path.abspath(dst), "log_rank0.txt"))
  257. if not onlylog:
  258. for file in _ckpts:
  259. _s = os.path.join(src, file)
  260. assert os.path.exists(_s)
  261. if os.path.exists(os.path.join(dst, file)):
  262. print(f"WARNING: file [{os.path.join(dst, file)}] exist already")
  263. with open(os.path.join(dst, file), "rb") as f:
  264. torch.load(f)
  265. else:
  266. assert os.path.exists(dst) and os.path.isdir(dst)
  267. print(f"copy from [{_s}] to [{dst}]")
  268. if not fake_copy:
  269. shutil.copy(_s, dst)
  270. def puremodel(ickptfile=".", opath=".", key="model", convert_key="model", name="vssmtmp"):
  271. ckptname = os.path.basename(ickptfile)
  272. ilogfile = os.path.join(os.path.dirname(ickptfile), "log_rank0.txt")
  273. opath = os.path.join(opath, name)
  274. ockptfile = os.path.join(opath, f"{name}_{ckptname}")
  275. ologfile = os.path.join(opath, f"{name}.txt")
  276. os.makedirs(opath, exist_ok=True)
  277. print(f"{name} =======================================")
  278. _ckpts = readlog_classification(ilogfile)
  279. _ckpt = torch.load(open(ickptfile, "rb"), map_location=torch.device("cpu"))
  280. if key not in _ckpt.keys():
  281. raise KeyError(f"key {key} not in ckpt.keys: {_ckpt.keys()}")
  282. _ockpt = {convert_key: _ckpt[key]}
  283. if os.path.exists(ockptfile):
  284. print(f"WARNING file {ockptfile} exists.")
  285. else:
  286. torch.save(_ockpt, open(ockptfile, "wb"))
  287. print(f"{ockptfile} saved...")
  288. assert os.path.exists(ilogfile), f"log file {ilogfile} not found"
  289. if os.path.exists(ologfile):
  290. print(f"WARNING file {ologfile} exists.")
  291. else:
  292. shutil.copy(ilogfile, ologfile)
  293. print(f"{ologfile} saved...")
  294. def puremodelmmdet(ilogfile, opath=".", fake_copy=False, mode="coco"):
  295. ilogfile = os.path.abspath(ilogfile)
  296. ilogfiledir = os.path.dirname(ilogfile)
  297. ipath = os.path.dirname(ilogfiledir)
  298. name = os.path.basename(ipath)
  299. configfile = os.path.join(ipath, f"{name}.py")
  300. assert os.path.exists(configfile), f"can not process directory: {os.listdir(ipath)}"
  301. dst = os.path.join(opath, name)
  302. ologfile = os.path.join(dst, f"{name}.log")
  303. os.makedirs(dst, exist_ok=True)
  304. print(f"\033[4;32m{name} =======================================\033[0m")
  305. for _s in [ilogfiledir, configfile]:
  306. _o = os.path.join(dst, os.path.basename(_s))
  307. if os.path.exists(_o):
  308. print(f"WARNING: file [{_o}] exist already")
  309. else:
  310. assert os.path.exists(dst) and os.path.isdir(dst)
  311. print(f"copy from [{_s}] to [{dst}]")
  312. if not fake_copy:
  313. shutil.copytree(_s, _o) if os.path.isdir(_s) else shutil.copy(_s, dst)
  314. assert os.path.exists(ilogfile), f"log file {ilogfile} not found"
  315. if os.path.exists(ologfile):
  316. print(f"WARNING file {ologfile} exists.")
  317. else:
  318. shutil.copy(ilogfile, ologfile)
  319. print(f"{ologfile} saved...")
  320. if mode in ["coco"]:
  321. _ckpts = readlog_mmdetection(ilogfile)
  322. elif mode in ["ade20k"]:
  323. _ckpts = readlog_mmsegmentation(ilogfile)
  324. for _s in _ckpts:
  325. ickptfile = os.path.join(ipath, _s)
  326. ockptfile = os.path.join(dst, f"{name}_{_s}")
  327. _ckpt = torch.load(open(ickptfile, "rb"), map_location=torch.device("cpu"))
  328. _ockpt = {"meta": _ckpt["meta"], "state_dict": _ckpt["state_dict"]}
  329. if os.path.exists(ockptfile):
  330. print(f"WARNING file {ockptfile} exists.")
  331. else:
  332. torch.save(_ockpt, open(ockptfile, "wb"))
  333. print(f"{ockptfile} saved...")
  334. def main_vssm():
  335. results = {}
  336. showpath = os.path.join(os.path.dirname(__file__), "./show/vssm1tifig")
  337. files = dict(
  338. )
  339. for name, file in files.items():
  340. x, accs, emaaccs = get_acc_swin(file, split_ema=True)
  341. lx, losses, avglosses = get_loss_swin(file, x1e=torch.tensor(list(range(0, 1251, 10))).view(1, -1) / 1251, scale=1)
  342. file = dict(xaxis=x, accs=accs, emaaccs=emaaccs, loss_xaxis=lx, losses=losses, avglosses=avglosses)
  343. results.update({name: file})
  344. draw_fig(data=[
  345. *[dict(x=file['xaxis'], y=file['accs']['acc1'], label=name) for name, file in results.items()],
  346. *[dict(x=file['xaxis'], y=file['emaaccs']['acc1'], label=f"{name}_ema") for name, file in results.items()],
  347. ], xlim=(30, 300), ylim=(70, 85), xstep=5, ystep=0.5, save_path=f"{showpath}/acc.jpg")
  348. draw_fig(data=[
  349. *[dict(x=file['loss_xaxis'], y=file['avglosses'], label=name) for name, file in results.items()],
  350. ], xlim=(10, 300), ylim=(2,5), save_path=f"{showpath}/loss.jpg")
  351. if __name__ == "__main__":
  352. main_vssm()