export_swanlab_backup.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. from __future__ import annotations
  2. import argparse
  3. import csv
  4. import json
  5. import sys
  6. from collections import defaultdict
  7. from pathlib import Path
  8. from typing import Any
  9. from swanlab.data.porter import DataPorter
  10. def parse_args() -> argparse.Namespace:
  11. parser = argparse.ArgumentParser(
  12. description="Export SwanLab backup.swanlab records to readable CSV/JSONL files."
  13. )
  14. parser.add_argument(
  15. "run_dir",
  16. help="SwanLab run directory, e.g. swanlog/run-20260530_115103-...",
  17. )
  18. parser.add_argument(
  19. "--out-dir",
  20. default=None,
  21. help="Output directory. Defaults to <run_dir>/exported.",
  22. )
  23. parser.add_argument(
  24. "--exclude-system",
  25. action="store_true",
  26. help="Exclude SwanLab system metrics whose keys start with __swanlab__.",
  27. )
  28. return parser.parse_args()
  29. def scalar_to_row(scalar: Any) -> dict[str, Any]:
  30. metric = scalar.metric or {}
  31. return {
  32. "key": scalar.key,
  33. "step": scalar.step,
  34. "epoch": scalar.epoch,
  35. "index": metric.get("index"),
  36. "data": metric.get("data"),
  37. "create_time": metric.get("create_time"),
  38. }
  39. def log_to_row(log: Any) -> dict[str, Any]:
  40. return {
  41. "level": log.level,
  42. "message": log.message,
  43. "create_time": log.create_time,
  44. "epoch": log.epoch,
  45. }
  46. def write_jsonl(path: Path, rows: list[dict[str, Any]]) -> None:
  47. with path.open("w", encoding="utf-8") as handle:
  48. for row in rows:
  49. handle.write(json.dumps(row, ensure_ascii=False, sort_keys=True) + "\n")
  50. def write_csv(path: Path, rows: list[dict[str, Any]], fieldnames: list[str]) -> None:
  51. with path.open("w", encoding="utf-8", newline="") as handle:
  52. writer = csv.DictWriter(handle, fieldnames=fieldnames)
  53. writer.writeheader()
  54. writer.writerows(rows)
  55. def build_epoch_table(scalar_rows: list[dict[str, Any]]) -> list[dict[str, Any]]:
  56. grouped: dict[int, dict[str, Any]] = defaultdict(dict)
  57. for row in scalar_rows:
  58. epoch = row.get("step")
  59. key = row.get("key")
  60. if epoch is None or key is None:
  61. continue
  62. if str(key).startswith("__swanlab__"):
  63. continue
  64. grouped[int(epoch)]["epoch"] = int(epoch)
  65. grouped[int(epoch)][str(key)] = row.get("data")
  66. return [grouped[epoch] for epoch in sorted(grouped)]
  67. def main() -> None:
  68. args = parse_args()
  69. run_dir = Path(args.run_dir)
  70. if not (run_dir / "backup.swanlab").exists():
  71. raise FileNotFoundError(f"backup.swanlab not found under {run_dir}")
  72. out_dir = Path(args.out_dir) if args.out_dir is not None else run_dir / "exported"
  73. out_dir.mkdir(parents=True, exist_ok=True)
  74. with DataPorter().open_for_sync(str(run_dir), backend="python") as porter:
  75. project, experiment = porter.parse()
  76. scalar_rows = [scalar_to_row(scalar) for scalar in porter._scalars]
  77. log_rows = [log_to_row(log) for log in porter._logs]
  78. if args.exclude_system:
  79. scalar_rows = [
  80. row for row in scalar_rows if not str(row["key"]).startswith("__swanlab__")
  81. ]
  82. scalar_fields = ["key", "step", "epoch", "index", "data", "create_time"]
  83. log_fields = ["level", "message", "create_time", "epoch"]
  84. write_csv(out_dir / "scalars.csv", scalar_rows, scalar_fields)
  85. write_jsonl(out_dir / "scalars.jsonl", scalar_rows)
  86. write_csv(out_dir / "logs.csv", log_rows, log_fields)
  87. write_jsonl(out_dir / "logs.jsonl", log_rows)
  88. epoch_rows = build_epoch_table(scalar_rows)
  89. if epoch_rows:
  90. fields = ["epoch"]
  91. for row in epoch_rows:
  92. for key in row:
  93. if key not in fields:
  94. fields.append(key)
  95. write_csv(out_dir / "epoch_metrics.csv", epoch_rows, fields)
  96. write_jsonl(out_dir / "epoch_metrics.jsonl", epoch_rows)
  97. summary = {
  98. "project": project.name,
  99. "experiment_id": experiment.id,
  100. "experiment_name": experiment.name,
  101. "num_scalars": len(scalar_rows),
  102. "num_logs": len(log_rows),
  103. "num_epoch_rows": len(epoch_rows),
  104. "out_dir": str(out_dir),
  105. }
  106. print(summary)
  107. if __name__ == "__main__":
  108. try:
  109. main()
  110. except Exception as exc:
  111. print(f"error: {type(exc).__name__}: {exc}", file=sys.stderr)
  112. raise