| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136 |
- from __future__ import annotations
- import argparse
- import csv
- import json
- import sys
- from collections import defaultdict
- from pathlib import Path
- from typing import Any
- from swanlab.data.porter import DataPorter
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the exporter's command-line options.

    Args:
        argv: Argument strings to parse, without the program name. When
            ``None`` (the default) argparse falls back to ``sys.argv[1:]``,
            so existing zero-argument callers behave as before. Passing an
            explicit list lets the parser be driven programmatically.

    Returns:
        A namespace with ``run_dir`` (str), ``out_dir`` (str or ``None``)
        and ``exclude_system`` (bool).
    """
    parser = argparse.ArgumentParser(
        description="Export SwanLab backup.swanlab records to readable CSV/JSONL files."
    )
    parser.add_argument(
        "run_dir",
        help="SwanLab run directory, e.g. swanlog/run-20260530_115103-...",
    )
    parser.add_argument(
        "--out-dir",
        default=None,
        help="Output directory. Defaults to <run_dir>/exported.",
    )
    parser.add_argument(
        "--exclude-system",
        action="store_true",
        help="Exclude SwanLab system metrics whose keys start with __swanlab__.",
    )
    return parser.parse_args(argv)
def scalar_to_row(scalar: Any) -> dict[str, Any]:
    """Flatten one scalar record into a plain dict suitable for export.

    ``key``, ``step`` and ``epoch`` are read directly from *scalar*.
    ``index``, ``data`` and ``create_time`` are looked up in
    ``scalar.metric`` and are ``None`` when that mapping is falsy
    (``None`` or empty) or does not contain the field.
    """
    payload = scalar.metric
    if not payload:
        # Treat a missing/empty metric as an empty mapping so lookups yield None.
        payload = {}

    row: dict[str, Any] = {
        "key": scalar.key,
        "step": scalar.step,
        "epoch": scalar.epoch,
    }
    for field in ("index", "data", "create_time"):
        row[field] = payload.get(field)
    return row
def log_to_row(log: Any) -> dict[str, Any]:
    """Copy a log record's exported attributes into a plain dict.

    The resulting keys, in order, are ``level``, ``message``,
    ``create_time`` and ``epoch``; each value is the attribute of the
    same name on *log*.
    """
    columns = ("level", "message", "create_time", "epoch")
    return {name: getattr(log, name) for name in columns}
def write_jsonl(path: Path, rows: list[dict[str, Any]]) -> None:
    """Write *rows* to *path* as JSON Lines, one object per line.

    The file is opened for writing as UTF-8 (replacing existing content).
    Each row is serialized with sorted keys and without ASCII escaping,
    followed by a newline. An empty *rows* produces an empty file.
    """
    with path.open("w", encoding="utf-8") as sink:
        sink.writelines(
            json.dumps(record, ensure_ascii=False, sort_keys=True) + "\n"
            for record in rows
        )
def write_csv(path: Path, rows: list[dict[str, Any]], fieldnames: list[str]) -> None:
    """Write *rows* to *path* as a CSV file with a header line.

    Args:
        path: Destination file; opened for writing as UTF-8.
        rows: Dicts to emit, one CSV record each.
        fieldnames: Column names, in output order, used for the header
            and to pick values out of each row.
    """
    # newline="" leaves line-ending handling to the csv module.
    with path.open("w", encoding="utf-8", newline="") as sink:
        table = csv.DictWriter(sink, fieldnames=fieldnames)
        table.writeheader()
        for record in rows:
            table.writerow(record)
def build_epoch_table(scalar_rows: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Pivot long-form scalar rows into one wide row per step.

    Rows are grouped by the integer value of their ``"step"`` field, which
    is exposed in the output under the ``"epoch"`` column (the input row's
    own ``"epoch"`` field is not consulted). Each metric key becomes a
    column holding that row's ``"data"`` value.

    Rows are skipped when ``"step"`` or ``"key"`` is ``None`` or when the
    key starts with ``__swanlab__``. When the same (step, key) pair occurs
    more than once, the last occurrence wins. The ``"epoch"`` column is
    re-stamped on every contributing row, so a metric literally named
    ``"epoch"`` shares that column and the final value depends on row order.

    Returns:
        The per-step dicts ordered by ascending step.
    """
    table: dict[int, dict[str, Any]] = {}
    for record in scalar_rows:
        raw_step = record.get("step")
        metric_name = record.get("key")
        if raw_step is None or metric_name is None:
            continue
        column = str(metric_name)
        if column.startswith("__swanlab__"):
            continue
        step = int(raw_step)
        bucket = table.setdefault(step, {})
        bucket["epoch"] = step
        bucket[column] = record.get("data")

    ordered_steps = sorted(table)
    return [table[step] for step in ordered_steps]
def main() -> None:
    """Export one SwanLab run's scalars and logs to CSV and JSONL files.

    Reads the run directory named on the command line, converts the
    scalar and log records exposed by ``DataPorter`` into plain rows, and
    writes ``scalars``/``logs`` (and, when non-empty, ``epoch_metrics``)
    as both ``.csv`` and ``.jsonl`` into the output directory. A summary
    dict is printed to stdout at the end.

    Raises:
        FileNotFoundError: If ``backup.swanlab`` is not present directly
            under the run directory.
    """
    options = parse_args()
    run_dir = Path(options.run_dir)
    backup_file = run_dir / "backup.swanlab"
    if not backup_file.exists():
        raise FileNotFoundError(f"backup.swanlab not found under {run_dir}")

    if options.out_dir is None:
        out_dir = run_dir / "exported"
    else:
        out_dir = Path(options.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    with DataPorter().open_for_sync(str(run_dir), backend="python") as porter:
        project, experiment = porter.parse()
        # NOTE(review): _scalars/_logs are private attributes of the porter;
        # presumably populated by parse() -- confirm against the SwanLab version in use.
        scalar_rows = list(map(scalar_to_row, porter._scalars))
        log_rows = list(map(log_to_row, porter._logs))

    if options.exclude_system:
        scalar_rows = [
            row
            for row in scalar_rows
            if not str(row["key"]).startswith("__swanlab__")
        ]

    # (csv name, jsonl name, rows, csv columns) for the two flat exports.
    exports = (
        (
            "scalars.csv",
            "scalars.jsonl",
            scalar_rows,
            ["key", "step", "epoch", "index", "data", "create_time"],
        ),
        (
            "logs.csv",
            "logs.jsonl",
            log_rows,
            ["level", "message", "create_time", "epoch"],
        ),
    )
    for csv_name, jsonl_name, rows, columns in exports:
        write_csv(out_dir / csv_name, rows, columns)
        write_jsonl(out_dir / jsonl_name, rows)

    epoch_rows = build_epoch_table(scalar_rows)
    if epoch_rows:
        # A dict keeps first-seen order while dropping repeated column names;
        # "epoch" is seeded so it is always the first column.
        seen_columns: dict[str, None] = dict.fromkeys(["epoch"])
        for row in epoch_rows:
            seen_columns.update(dict.fromkeys(row))
        write_csv(out_dir / "epoch_metrics.csv", epoch_rows, list(seen_columns))
        write_jsonl(out_dir / "epoch_metrics.jsonl", epoch_rows)

    summary = {
        "project": project.name,
        "experiment_id": experiment.id,
        "experiment_name": experiment.name,
        "num_scalars": len(scalar_rows),
        "num_logs": len(log_rows),
        "num_epoch_rows": len(epoch_rows),
        "out_dir": str(out_dir),
    }
    print(summary)
if __name__ == "__main__":
    # Emit a one-line summary on stderr, then re-raise so the original
    # exception (and its traceback) still propagates.
    try:
        main()
    except Exception as exc:
        kind = type(exc).__name__
        print(f"error: {kind}: {exc}", file=sys.stderr)
        raise
|