analyze_logs.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. # Copyright (c) OpenMMLab. All rights reserved.
"""Modified from
https://github.com/open-mmlab/mmdetection/blob/master/tools/analysis_tools/analyze_logs.py."""
  4. import argparse
  5. import json
  6. from collections import defaultdict
  7. import matplotlib.pyplot as plt
  8. import seaborn as sns
  9. def plot_curve(log_dicts, args):
  10. if args.backend is not None:
  11. plt.switch_backend(args.backend)
  12. sns.set_style(args.style)
  13. # if legend is None, use {filename}_{key} as legend
  14. legend = args.legend
  15. if legend is None:
  16. legend = []
  17. for json_log in args.json_logs:
  18. for metric in args.keys:
  19. legend.append(f'{json_log}_{metric}')
  20. assert len(legend) == (len(args.json_logs) * len(args.keys))
  21. metrics = args.keys
  22. num_metrics = len(metrics)
  23. for i, log_dict in enumerate(log_dicts):
  24. epochs = list(log_dict.keys())
  25. for j, metric in enumerate(metrics):
  26. print(f'plot curve of {args.json_logs[i]}, metric is {metric}')
  27. plot_epochs = []
  28. plot_iters = []
  29. plot_values = []
  30. # In some log files exist lines of validation,
  31. # `mode` list is used to only collect iter number
  32. # of training line.
  33. for epoch in epochs:
  34. epoch_logs = log_dict[epoch]
  35. if metric not in epoch_logs.keys():
  36. continue
  37. if metric in ['mIoU', 'mAcc', 'aAcc']:
  38. plot_epochs.append(epoch)
  39. plot_values.append(epoch_logs[metric][0])
  40. else:
  41. for idx in range(len(epoch_logs[metric])):
  42. plot_iters.append(epoch_logs['step'][idx])
  43. plot_values.append(epoch_logs[metric][idx])
  44. ax = plt.gca()
  45. label = legend[i * num_metrics + j]
  46. if metric in ['mIoU', 'mAcc', 'aAcc']:
  47. ax.set_xticks(plot_epochs)
  48. plt.xlabel('step')
  49. plt.plot(plot_epochs, plot_values, label=label, marker='o')
  50. else:
  51. plt.xlabel('iter')
  52. plt.plot(plot_iters, plot_values, label=label, linewidth=0.5)
  53. plt.legend()
  54. if args.title is not None:
  55. plt.title(args.title)
  56. if args.out is None:
  57. plt.show()
  58. else:
  59. print(f'save curve to: {args.out}')
  60. plt.savefig(args.out)
  61. plt.cla()
  62. def parse_args():
  63. parser = argparse.ArgumentParser(description='Analyze Json Log')
  64. parser.add_argument(
  65. 'json_logs',
  66. type=str,
  67. nargs='+',
  68. help='path of train log in json format')
  69. parser.add_argument(
  70. '--keys',
  71. type=str,
  72. nargs='+',
  73. default=['mIoU'],
  74. help='the metric that you want to plot')
  75. parser.add_argument('--title', type=str, help='title of figure')
  76. parser.add_argument(
  77. '--legend',
  78. type=str,
  79. nargs='+',
  80. default=None,
  81. help='legend of each plot')
  82. parser.add_argument(
  83. '--backend', type=str, default=None, help='backend of plt')
  84. parser.add_argument(
  85. '--style', type=str, default='dark', help='style of plt')
  86. parser.add_argument('--out', type=str, default=None)
  87. args = parser.parse_args()
  88. return args
  89. def load_json_logs(json_logs):
  90. # load and convert json_logs to log_dict, key is step, value is a sub dict
  91. # keys of sub dict is different metrics
  92. # value of sub dict is a list of corresponding values of all iterations
  93. log_dicts = [dict() for _ in json_logs]
  94. prev_step = 0
  95. for json_log, log_dict in zip(json_logs, log_dicts):
  96. with open(json_log) as log_file:
  97. for line in log_file:
  98. log = json.loads(line.strip())
  99. # the final step in json file is 0.
  100. if 'step' in log and log['step'] != 0:
  101. step = log['step']
  102. prev_step = step
  103. else:
  104. step = prev_step
  105. if step not in log_dict:
  106. log_dict[step] = defaultdict(list)
  107. for k, v in log.items():
  108. log_dict[step][k].append(v)
  109. return log_dicts
  110. def main():
  111. args = parse_args()
  112. json_logs = args.json_logs
  113. for json_log in json_logs:
  114. assert json_log.endswith('.json')
  115. log_dicts = load_json_logs(json_logs)
  116. plot_curve(log_dicts, args)
  117. if __name__ == '__main__':
  118. main()